In [1]:
# run "pip installipython-autotime" in your conda env
%load_ext autotime

print('hello world')
hello world
time: 242 μs (started: 2025-11-03 21:13:18 +01:00)
In [2]:
# Import bayesDREAM with reload capability
import importlib

import sys
import os
import torch
from pathlib import Path
# Add the directory containing 'bayesDREAM' to sys.path
base_path = Path('/cfs/klemming/projects/snic/lappalainen_lab1/users/Leah/CRISPRmodelling/BayesianModelling/bayesDREAM_forClaude/bayesDREAM code/bayesDREAM_forClaude/')  # Adjust relative path from your notebook
sys.path.append(str(base_path))
# Now import the model
import bayesDREAM

# Reload only if the module was already imported
importlib.reload(bayesDREAM)
from bayesDREAM import bayesDREAM
time: 1min 57s (started: 2025-11-03 21:13:18 +01:00)
In [3]:
import pandas as pd
time: 167 μs (started: 2025-11-03 21:15:15 +01:00)
In [4]:
deviceno = 0
device = torch.device(f'cuda:{deviceno}' if torch.cuda.is_available() else 'cpu')
time: 53.5 ms (started: 2025-11-03 21:15:15 +01:00)

Load data¶

In [5]:
data_dir = '/cfs/klemming/projects/snic/lappalainen_lab1/users/Leah/data/unpublished/Josie_pilot/processed/fromJosie_Oct2025/SCE_from_Josie/10X_SR/'

cell_meta = pd.read_csv(data_dir + '10X_SR_cell_meta.csv')
gene_meta = pd.read_csv(data_dir + '10X_SR_gene_meta.csv')
gene_counts = pd.read_csv(data_dir + '10X_SR_counts.csv', index_col=None)
gene_counts.index = gene_meta['Symbol'].values
gene_meta.index = gene_meta['Symbol'].values
gene_meta = gene_meta.rename(columns={'ID': 'gene_id', 'Symbol': 'gene_name'})
cell_meta['cell'] = cell_meta['Barcode']
cell_meta.loc[cell_meta['target'] == 'NTC', 'target'] = 'ntc'
time: 2.56 s (started: 2025-11-03 21:15:15 +01:00)

Create bayesDREAM objects¶

In [6]:
model = {}
for cg in ['GFI1B', 'GEMIN5', 'DDX6']:
    model[cg] = bayesDREAM(
        meta=cell_meta,
        counts=gene_counts,
        gene_meta=gene_meta,
        cis_gene=cg,
        output_dir=data_dir+'bayesDREAM/output/',
        sum_factor_col = 'clustered.sum.factor',
        label='20251030_' + cg,
        device = device
    )
[INFO] Extracting 'cis' modality from gene 'GFI1B'
[INFO] Creating 'gene' modality with trans genes (excluding 'GFI1B')
[VALIDATION] Primary modality 'gene' is 'negbinom' - cis modeling is valid
[INFO] Using 'gene_name' column as 'gene' identifier
[INFO] Gene metadata loaded with 13792 genes and columns: ['gene', 'gene_name', 'gene_id']
/cfs/klemming/projects/snic/lappalainen_lab1/users/Leah/CRISPRmodelling/BayesianModelling/bayesDREAM_forClaude/bayesDREAM code/bayesDREAM_forClaude/bayesDREAM/core.py:220: UserWarning: Subsetting reduced the number of cells in the metadata from 2600 to 1884. This may impact downstream analysis.
  warnings.warn(
[INIT] bayesDREAM core: label=20251030_GFI1B, device=cpu
[INFO] Subsetting modalities to 1884 cells from filtered meta
[INFO] Subsetting modality 'cis' from 2600 to 1884 cells
[INFO] Subsetting modality 'gene' from 2600 to 1884 cells
[INIT] bayesDREAM: 2 modalities loaded
  - cis: Modality(name='cis', distribution='negbinom', dims={'n_features': 1, 'n_cells': 1884})
  - gene: Modality(name='gene', distribution='negbinom', dims={'n_features': 13791, 'n_cells': 1884})
[INFO] Extracting 'cis' modality from gene 'GEMIN5'
[INFO] Creating 'gene' modality with trans genes (excluding 'GEMIN5')
[VALIDATION] Primary modality 'gene' is 'negbinom' - cis modeling is valid
[INFO] Using 'gene_name' column as 'gene' identifier
[INFO] Gene metadata loaded with 13792 genes and columns: ['gene', 'gene_name', 'gene_id']
[INIT] bayesDREAM core: label=20251030_GEMIN5, device=cpu
[INFO] Subsetting modalities to 1879 cells from filtered meta
/cfs/klemming/projects/snic/lappalainen_lab1/users/Leah/CRISPRmodelling/BayesianModelling/bayesDREAM_forClaude/bayesDREAM code/bayesDREAM_forClaude/bayesDREAM/core.py:220: UserWarning: Subsetting reduced the number of cells in the metadata from 2600 to 1879. This may impact downstream analysis.
  warnings.warn(
[INFO] Subsetting modality 'cis' from 2600 to 1879 cells
[INFO] Subsetting modality 'gene' from 2600 to 1879 cells
[INIT] bayesDREAM: 2 modalities loaded
  - cis: Modality(name='cis', distribution='negbinom', dims={'n_features': 1, 'n_cells': 1879})
  - gene: Modality(name='gene', distribution='negbinom', dims={'n_features': 13791, 'n_cells': 1879})
[INFO] Extracting 'cis' modality from gene 'DDX6'
[INFO] Creating 'gene' modality with trans genes (excluding 'DDX6')
[VALIDATION] Primary modality 'gene' is 'negbinom' - cis modeling is valid
[INFO] Using 'gene_name' column as 'gene' identifier
[INFO] Gene metadata loaded with 13792 genes and columns: ['gene', 'gene_name', 'gene_id']
[INIT] bayesDREAM core: label=20251030_DDX6, device=cpu
[INFO] Subsetting modalities to 1819 cells from filtered meta
/cfs/klemming/projects/snic/lappalainen_lab1/users/Leah/CRISPRmodelling/BayesianModelling/bayesDREAM_forClaude/bayesDREAM code/bayesDREAM_forClaude/bayesDREAM/core.py:220: UserWarning: Subsetting reduced the number of cells in the metadata from 2600 to 1819. This may impact downstream analysis.
  warnings.warn(
[INFO] Subsetting modality 'cis' from 2600 to 1819 cells
[INFO] Subsetting modality 'gene' from 2600 to 1819 cells
[INIT] bayesDREAM: 2 modalities loaded
  - cis: Modality(name='cis', distribution='negbinom', dims={'n_features': 1, 'n_cells': 1819})
  - gene: Modality(name='gene', distribution='negbinom', dims={'n_features': 13791, 'n_cells': 1819})
time: 1.82 s (started: 2025-11-03 21:15:18 +01:00)
In [7]:
for cg in ['GFI1B', 'GEMIN5', 'DDX6']:
    model[cg].adjust_ntc_sum_factor(sum_factor_col_old='clustered.sum.factor')
[INFO] Created 'sum_factor_adj' in meta with NTC-based guide-level adjustment.
[INFO] Created 'sum_factor_adj' in meta with NTC-based guide-level adjustment.
[INFO] Created 'sum_factor_adj' in meta with NTC-based guide-level adjustment.
time: 108 ms (started: 2025-11-03 21:15:20 +01:00)
In [8]:
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np

cgs = ['GFI1B', 'GEMIN5', 'DDX6']

# --- Define consistent palettes ---
palette = {
    'GFI1B': [cm.Greens(i) for i in np.linspace(0.4, 0.9, 3)],   # GFI1B_[1-3]
    'NTC':   [cm.Greys(i)  for i in np.linspace(0.4, 0.8, 5)],   # NTC_[1-5]
    'GEMIN5':[cm.Blues(i)  for i in np.linspace(0.4, 0.8, 2)],   # GEMIN5_[1-2]
    'DDX6':  [cm.Reds(i)   for i in np.linspace(0.4, 0.8, 3)],   # DDX6_[1,3]
}

# Flatten into guide→color dictionary
guide_colors = {}
for gene, colors in palette.items():
    for i, color in enumerate(colors, start=1):
        guide_colors[f"{gene}_{i}"] = color

# --- Plot ---
fig, axes = plt.subplots(1, len(cgs), figsize=(5*len(cgs), 4), sharex=False, sharey=False)
if len(cgs) == 1:
    axes = [axes]

for ax, cg in zip(axes, cgs):
    df = model[cg].meta.copy()
    df = df[(df['clustered.sum.factor'] > 0) & (df['sum_factor_adj'] > 0)]

    for guide, sub in df.groupby('guide'):
        color = guide_colors.get(guide, 'black')
        ax.scatter(
            sub['clustered.sum.factor'],
            sub['sum_factor_adj'],
            s=12,
            alpha=0.8,
            color=color,
            label=guide,
        )

    ax.set_xscale('log', base=2)
    ax.set_yscale('log', base=2)
    ax.set_title(cg)
    ax.set_xlabel('clustered.sum.factor (log₂)')
    ax.set_ylabel('sum_factor_adj (log₂)')
    ax.grid(True, linewidth=0.5, alpha=0.4)
    ax.legend(title='guide', fontsize=8, markerscale=1.2, frameon=False)

plt.tight_layout()
plt.show()
No description has been provided for this image
time: 1.38 s (started: 2025-11-03 21:15:20 +01:00)

Fit cis¶

In [9]:
for cg in ['GFI1B', 'GEMIN5', 'DDX6']:
    print(cg)
    print(model[cg].list_modalities())
    print()
GFI1B
   name distribution  n_features  n_cells
0   cis     negbinom           1     1884
1  gene     negbinom       13791     1884

GEMIN5
   name distribution  n_features  n_cells
0   cis     negbinom           1     1879
1  gene     negbinom       13791     1879

DDX6
   name distribution  n_features  n_cells
0   cis     negbinom           1     1819
1  gene     negbinom       13791     1819

time: 19.2 ms (started: 2025-11-03 21:15:21 +01:00)

fit¶

In [10]:
for cg in ['GFI1B', 'GEMIN5', 'DDX6']:
    # --- CIS FIT: Load if exists, otherwise fit and save ---
    cis_fit_path = os.path.join(model[cg].output_dir, model[cg].label, 'x_true.pt')
    if os.path.exists(cis_fit_path):
        print("[INFO] Loading existing cis fit...")
        model[cg].load_cis_fit()
    else:
        print("[INFO] Running cis fit (this may take a while)...")
        model[cg].fit_cis(sum_factor_col="sum_factor_adj", tolerance=0, niters=100000)
        model[cg].save_cis_fit()
[INFO] Loading existing cis fit...
[LOAD] x_true (posterior) ← /cfs/klemming/projects/snic/lappalainen_lab1/users/Leah/data/unpublished/Josie_pilot/processed/fromJosie_Oct2025/SCE_from_Josie/10X_SR/bayesDREAM/output/20251030_GFI1B/x_true.pt
[LOAD] Cis fit loaded from /cfs/klemming/projects/snic/lappalainen_lab1/users/Leah/data/unpublished/Josie_pilot/processed/fromJosie_Oct2025/SCE_from_Josie/10X_SR/bayesDREAM/output/20251030_GFI1B
[INFO] Loading existing cis fit...
[LOAD] x_true (posterior) ← /cfs/klemming/projects/snic/lappalainen_lab1/users/Leah/data/unpublished/Josie_pilot/processed/fromJosie_Oct2025/SCE_from_Josie/10X_SR/bayesDREAM/output/20251030_GEMIN5/x_true.pt
[LOAD] Cis fit loaded from /cfs/klemming/projects/snic/lappalainen_lab1/users/Leah/data/unpublished/Josie_pilot/processed/fromJosie_Oct2025/SCE_from_Josie/10X_SR/bayesDREAM/output/20251030_GEMIN5
[INFO] Loading existing cis fit...
[LOAD] x_true (posterior) ← /cfs/klemming/projects/snic/lappalainen_lab1/users/Leah/data/unpublished/Josie_pilot/processed/fromJosie_Oct2025/SCE_from_Josie/10X_SR/bayesDREAM/output/20251030_DDX6/x_true.pt
[LOAD] Cis fit loaded from /cfs/klemming/projects/snic/lappalainen_lab1/users/Leah/data/unpublished/Josie_pilot/processed/fromJosie_Oct2025/SCE_from_Josie/10X_SR/bayesDREAM/output/20251030_DDX6
time: 772 ms (started: 2025-11-03 21:15:21 +01:00)

Plot results¶

In [11]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from scipy.stats import gaussian_kde

# ----------------------------
# Palette and color utilities
# ----------------------------
palette = {
    'GFI1B': [cm.Greens(i) for i in np.linspace(0.4, 0.9, 3)],   # GFI1B_[1-3]
    'NTC':   [cm.Greys(i)  for i in np.linspace(0.4, 0.8, 5)],   # NTC_[1-5]
    'GEMIN5':[cm.Blues(i)  for i in np.linspace(0.4, 0.8, 2)],   # GEMIN5_[1-2]
    'DDX6':  [cm.Reds(i)   for i in np.linspace(0.4, 0.8, 3)],   # DDX6_[1,3]
}

def build_guide_colors(palette_dict):
    guide_colors = {}
    for gene, colors in palette_dict.items():
        for i, color in enumerate(colors, start=1):
            guide_colors[f"{gene}_{i}"] = color
    return guide_colors

guide_colors = build_guide_colors(palette)

# Optional: target color map (violin plots)
target_colors = {
    'GFI1B': cm.Greens(0.7),
    'NTC':   cm.Greys(0.6),
    'GEMIN5':cm.Blues(0.7),
    'DDX6':  cm.Reds(0.7),
    'ntc':   cm.Greys(0.6),  # if you've normalized NTC to lowercase
}

# ----------------------------
# Helpers
# ----------------------------
def to_np(a):
    """Safely convert torch/array-like to numpy."""
    try:
        import torch
        if isinstance(a, torch.Tensor):
            return a.detach().cpu().numpy()
    except Exception:
        pass
    return np.asarray(a)

def per_cell_mean_std(x):
    """Compute per-cell mean and std along axis 0 (samples x cells)."""
    x_np = to_np(x)
    return x_np.mean(axis=0), x_np.std(axis=0)

# ----------------------------
# 1) Scatter plots colored by guide
# ----------------------------
def scatter_by_guide(model, cg, log2=False):
    df = model[cg].meta.copy()
    X = to_np(model[cg].x_true)

    if log2:
        # filter strictly positive before log
        mask_pos = (X > 0).all(axis=0)
        X = X[:, mask_pos]
        df = df.loc[mask_pos].reset_index(drop=True)
        X = np.log2(X)

    x_mean, x_std = X.mean(axis=0), X.std(axis=0)

    plt.figure(figsize=(6, 5))
    for guide, subidx in df.groupby('guide').groups.items():
        color = guide_colors.get(guide, 'black')
        plt.scatter(x_mean[subidx], x_std[subidx], s=14, alpha=0.8, color=color, label=guide)

    plt.xlabel('mean x_true' + (' (log2)' if log2 else ''))
    plt.ylabel('std x_true' + (' (log2)' if log2 else ''))
    plt.title(f'{cg}: mean vs std of x_true' + (' (log2)' if log2 else ''))
    plt.grid(True, linewidth=0.5, alpha=0.4)
    plt.legend(title='guide', fontsize=8, markerscale=1.2, frameon=False)
    plt.tight_layout()
    plt.show()

# Usage (your two plots):
# raw scale
# scatter_by_guide(model, cg, log2=False)
# log2 scale
# scatter_by_guide(model, cg, log2=True)

def scatter_ci95_by_guide(model, cg, log2=False, full_width=False):
    """
    Scatter of per-cell mean vs 95% CI width (or half-width) of x_true samples.
      - x: mean over samples
      - y: CI_95 width (q97.5 - q2.5) if full_width=True,
           else half-width = 0.5 * (q97.5 - q2.5)
    Colors points by model[cg].meta['guide'] using guide_colors.
    """
    df = model[cg].meta.copy()
    X = to_np(model[cg].x_true)  # shape [S, N] (samples x cells)

    if log2:
        # Keep only cells strictly positive across samples before log2
        mask_pos = (X > 0).all(axis=0)
        X = X[:, mask_pos]
        df = df.loc[mask_pos].reset_index(drop=True)
        X = np.log2(X)

    x_mean = X.mean(axis=0)
    q_lo  = np.percentile(X, 2.5, axis=0)
    q_hi  = np.percentile(X, 97.5, axis=0)
    ci_w  = (q_hi - q_lo)
    y_val = ci_w if full_width else 0.5 * ci_w  # half-width by default

    plt.figure(figsize=(6, 5))
    for guide, idx in df.groupby('guide').groups.items():
        color = guide_colors.get(guide, 'black')
        plt.scatter(x_mean[idx], y_val[idx], s=14, alpha=0.85, color=color, label=guide)

    plt.xlabel('mean x_true' + (' (log2)' if log2 else ''))
    ylabel = '95% CI ' + ('width' if full_width else 'half-width')
    ylabel += ' of x_true' + (' (log2)' if log2 else '')
    plt.ylabel(ylabel)
    plt.title(f'{cg}: mean vs 95% CI of x_true' + (' (log2)' if log2 else ''))
    plt.grid(True, linewidth=0.5, alpha=0.4)
    plt.legend(title='guide', fontsize=8, markerscale=1.2, frameon=False)
    plt.tight_layout()
    plt.show()


# ----------------------------------------------------------
# 2) Violin: x-axis = guide, color = target, x_true on log2
# ----------------------------------------------------------
def violin_by_guide_log2(model, cg):
    df = model[cg].meta.copy()
    X = to_np(model[cg].x_true)

    # Keep only cells with strictly positive across samples before log2
    pos_mask = (X > 0).all(axis=0)
    X = X[:, pos_mask]
    df = df.loc[pos_mask].reset_index(drop=True)

    Xlog = np.log2(X)
    x_cell_mean = Xlog.mean(axis=0)
    df = df.assign(x_true_mean_log2=x_cell_mean)

    # Order guides nicely
    guide_order = sorted(df['guide'].astype(str).unique(),
                         key=lambda g: (g.split('_')[0], int(g.split('_')[1]) if '_' in g and g.split('_')[1].isdigit() else 0))
    data = [df.loc[df['guide'] == g, 'x_true_mean_log2'].values for g in guide_order]

    # Color each violin by its guide's target
    colors = []
    for g in guide_order:
        tvals = df.loc[df['guide'] == g, 'target'].astype(str).unique()
        t = tvals[0] if len(tvals) else 'NTC'
        colors.append(target_colors.get(t, 'gray'))

    plt.figure(figsize=(max(6, 1.2*len(guide_order)), 4.8))
    parts = plt.violinplot(data, showmeans=True, showextrema=False)

    for body, c in zip(parts['bodies'], colors):
        body.set_facecolor(c)
        body.set_edgecolor('black')
        body.set_alpha(0.85)
    # mean line style
    if 'cmeans' in parts:
        parts['cmeans'].set_color('black')
        parts['cmeans'].set_linewidth(1.0)

    plt.xticks(ticks=np.arange(1, len(guide_order)+1), labels=guide_order, rotation=45, ha='right')
    plt.xlabel('guide')
    plt.ylabel('x_true mean (log₂)')
    plt.title(f'{cg}: x_true (log₂) by guide (colored by target)')
    plt.grid(True, linewidth=0.5, alpha=0.4, axis='y')
    plt.tight_layout()
    plt.show()

# ----------------------------------------------------------
# 3) Density (KDE): filled, color = guide, x_true on log2
#    KDE over per-cell mean of x_true (log2)
# ----------------------------------------------------------
def filled_density_by_guide_log2(model, cg, bw=None):
    df = model[cg].meta.copy()
    X = to_np(model[cg].x_true)

    pos_mask = (X > 0).all(axis=0)
    X = X[:, pos_mask]
    df = df.loc[pos_mask].reset_index(drop=True)

    Xlog = np.log2(X)
    x_cell_mean = Xlog.mean(axis=0)
    df = df.assign(x_true_mean_log2=x_cell_mean)

    # global x-range for all guides
    xmin, xmax = np.percentile(x_cell_mean, [0.5, 99.5])
    xs = np.linspace(xmin, xmax, 400)

    plt.figure(figsize=(7, 4.8))

    # Keep legend order stable
    guides = sorted(df['guide'].astype(str).unique(),
                    key=lambda g: (g.split('_')[0], int(g.split('_')[1]) if '_' in g and g.split('_')[1].isdigit() else 0))

    for g in guides:
        vals = df.loc[df['guide'] == g, 'x_true_mean_log2'].values
        color = guide_colors.get(g, 'black')
        if len(np.unique(vals)) < 2:
            # Not enough variance for KDE; plot a small filled bump
            y = np.exp(-0.5*((xs - vals[0]) / 0.01)**2)  # tiny Gaussian bump
            plt.fill_between(xs, 0*y, y, color=color, alpha=0.35, label=g)
            continue
        kde = gaussian_kde(vals, bw_method=bw)
        ys = kde(xs)
        plt.fill_between(xs, 0, ys, color=color, alpha=0.35, label=g)
        plt.plot(xs, ys, color=color, linewidth=1.5)

    plt.xlabel('x_true mean (log₂)')
    plt.ylabel('density')
    plt.title(f'{cg}: filled density by guide (log₂)')
    plt.grid(True, linewidth=0.5, alpha=0.4)
    plt.legend(title='guide', fontsize=8, frameon=False, ncol=2)
    plt.tight_layout()
    plt.show()
time: 5.68 ms (started: 2025-11-03 21:15:22 +01:00)
In [12]:
# -----------------------------------
# Example calls for your three genes
# -----------------------------------
cgs = ['GFI1B', 'GEMIN5', 'DDX6']
for cg in cgs:
    # Your two scatter plots:
    scatter_by_guide(model, cg, log2=False)
    scatter_by_guide(model, cg, log2=True)
    
    # Your two scatter plots:
    scatter_ci95_by_guide(model, cg, log2=False, full_width=True)
    scatter_ci95_by_guide(model, cg, log2=True, full_width=True)

    # Violin by target:
    violin_by_guide_log2(model, cg)   # or log2=True if you prefer

    # Density by guide:
    filled_density_by_guide_log2(model, cg)   # or log2=True for log2 scale
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
time: 4.26 s (started: 2025-11-03 21:15:22 +01:00)
In [13]:
cgs = ['GFI1B', 'GEMIN5', 'DDX6']
for cg in cgs:
    model[cg].set_technical_groups(['Sample'])
[INFO] Set technical_group_code with 1 groups based on ['Sample']
[INFO] Set technical_group_code with 1 groups based on ['Sample']
[INFO] Set technical_group_code with 1 groups based on ['Sample']
time: 5.06 ms (started: 2025-11-03 21:15:26 +01:00)
In [14]:
cgs = ['GFI1B', 'GEMIN5', 'DDX6']
tgs = ['MYB', 'HES4', 'GAPDH']
for tg in tgs:
    for cg in cgs:
        model[cg].plot_xy_data(tg, window=100, sum_factor_col='sum_factor_adj', show_correction='uncorrected');
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
time: 2.04 s (started: 2025-11-03 21:15:26 +01:00)
In [15]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# --- Prepare data ---
counts = model['GEMIN5'].counts
meta   = model['GEMIN5'].meta.copy()

# Ensure meta is indexed by cell IDs so it can align to counts.columns
if 'cell' in meta.columns and not meta.index.equals(counts.columns):
    meta = meta.set_index('cell')

# Build an aligned mask for the counts' columns
mask = meta['target'].reindex(counts.columns).eq('ntc').fillna(False)

# Extract counts for GEMIN5 in NTC cells
gemin5_ntc = counts.loc['GEMIN5', mask]

# --- Plot ---
plt.figure(figsize=(6, 4))
counts_hist, bins, patches = plt.hist(
    gemin5_ntc,
    bins=30,
    color="#4C72B0",
    edgecolor="black",
    alpha=0.8
)

# --- Add percentage labels above each bar ---
total = counts_hist.sum()
for count, bin_left, bin_right in zip(counts_hist, bins[:-1], bins[1:]):
    if count > 0:
        percent = 100 * count / total
        plt.text(
            (bin_left + bin_right) / 2,
            count,
            f"{percent:.1f}%",
            ha="center",
            va="bottom",
            fontsize=8,
            rotation=0
        )

# --- Styling ---
plt.title("GEMIN5 Expression in NTC Cells", fontsize=14, weight="bold", pad=15)
plt.xlabel("Counts", fontsize=12)
plt.ylabel("Number of Cells", fontsize=12)
plt.grid(True, linestyle="--", linewidth=0.5, alpha=0.5)
sns.despine(trim=True)
plt.tight_layout()
plt.show()
No description has been provided for this image
time: 131 ms (started: 2025-11-03 21:15:28 +01:00)
In [16]:
np.mean(gemin5_ntc > 0)
Out[16]:
np.float64(0.39168343393695504)
time: 1.48 ms (started: 2025-11-03 21:15:29 +01:00)

Fit trans¶

In [17]:
# remove hacky tech group added
for cg in ['GFI1B', 'GEMIN5', 'DDX6']:
    model[cg].meta.drop('technical_group_code', axis=1, inplace=True)
time: 6.82 ms (started: 2025-11-03 21:15:29 +01:00)

Refit sumfactor¶

In [18]:
for cg in ['GFI1B', 'GEMIN5', 'DDX6']:
    model[cg].refit_sumfactor(sum_factor_col_old='clustered.sum.factor')
[INFO] Created 'sum_factor_new' in meta with xtrue-based adjustment.
[INFO] Created 'sum_factor_new' in meta with xtrue-based adjustment.
[INFO] Created 'sum_factor_new' in meta with xtrue-based adjustment.
time: 731 ms (started: 2025-11-03 21:15:29 +01:00)
In [19]:
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np

cgs = ['GFI1B', 'GEMIN5', 'DDX6']

# --- Define consistent palettes ---
palette = {
    'GFI1B': [cm.Greens(i) for i in np.linspace(0.4, 0.9, 3)],   # GFI1B_[1-3]
    'NTC':   [cm.Greys(i)  for i in np.linspace(0.4, 0.8, 5)],   # NTC_[1-5]
    'GEMIN5':[cm.Blues(i)  for i in np.linspace(0.4, 0.8, 2)],   # GEMIN5_[1-2]
    'DDX6':  [cm.Reds(i)   for i in np.linspace(0.4, 0.8, 3)],   # DDX6_[1,3]
}

# Flatten into guide→color dictionary
guide_colors = {}
for gene, colors in palette.items():
    for i, color in enumerate(colors, start=1):
        guide_colors[f"{gene}_{i}"] = color

# --- Plot ---
fig, axes = plt.subplots(1, len(cgs), figsize=(5*len(cgs), 4), sharex=False, sharey=False)
if len(cgs) == 1:
    axes = [axes]

for ax, cg in zip(axes, cgs):
    df = model[cg].meta.copy()
    df = df[(df['clustered.sum.factor'] > 0) & (df['sum_factor_new'] > 0)]

    for guide, sub in df.groupby('guide'):
        color = guide_colors.get(guide, 'black')
        ax.scatter(
            sub['clustered.sum.factor'],
            sub['sum_factor_new'],
            s=12,
            alpha=0.8,
            color=color,
            label=guide,
        )

    ax.set_xscale('log', base=2)
    ax.set_yscale('log', base=2)
    ax.set_title(cg)
    ax.set_xlabel('clustered.sum.factor (log₂)')
    ax.set_ylabel('sum_factor_new (log₂)')
    ax.grid(True, linewidth=0.5, alpha=0.4)
    ax.legend(title='guide', fontsize=8, markerscale=1.2, frameon=False)

plt.tight_layout()
plt.show()
No description has been provided for this image
time: 678 ms (started: 2025-11-03 21:15:29 +01:00)
In [20]:
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np

cgs = ['GFI1B', 'GEMIN5', 'DDX6']

# --- Define consistent palettes ---
palette = {
    'GFI1B': [cm.Greens(i) for i in np.linspace(0.4, 0.9, 3)],   # GFI1B_[1-3]
    'NTC':   [cm.Greys(i)  for i in np.linspace(0.4, 0.8, 5)],   # NTC_[1-5]
    'GEMIN5':[cm.Blues(i)  for i in np.linspace(0.4, 0.8, 2)],   # GEMIN5_[1-2]
    'DDX6':  [cm.Reds(i)   for i in np.linspace(0.4, 0.8, 3)],   # DDX6_[1,3]
}

# Flatten into guide→color dictionary
guide_colors = {}
for gene, colors in palette.items():
    for i, color in enumerate(colors, start=1):
        guide_colors[f"{gene}_{i}"] = color

# --- Plot ---
fig, axes = plt.subplots(1, len(cgs), figsize=(5*len(cgs), 4), sharex=False, sharey=False)
if len(cgs) == 1:
    axes = [axes]

for ax, cg in zip(axes, cgs):
    df = model[cg].meta.copy()
    df = df[(df['sum_factor_adj'] > 0) & (df['sum_factor_new'] > 0)]

    for guide, sub in df.groupby('guide'):
        color = guide_colors.get(guide, 'black')
        ax.scatter(
            sub['sum_factor_adj'],
            sub['sum_factor_new'],
            s=12,
            alpha=0.8,
            color=color,
            label=guide,
        )

    ax.set_xscale('log', base=2)
    ax.set_yscale('log', base=2)
    ax.set_title(cg)
    ax.set_xlabel('sum_factor_adj (log₂)')
    ax.set_ylabel('sum_factor_new (log₂)')
    ax.grid(True, linewidth=0.5, alpha=0.4)
    ax.legend(title='guide', fontsize=8, markerscale=1.2, frameon=False)

plt.tight_layout()
plt.show()
No description has been provided for this image
time: 814 ms (started: 2025-11-03 21:15:30 +01:00)

Fit¶

In [21]:
for cg in ['GFI1B', 'GEMIN5', 'DDX6']:
    # --- TRANS FIT: Load if exists, otherwise fit and save ---
    trans_fit_path = os.path.join(model[cg].output_dir, model[cg].label, 'posterior_samples_trans_gene.pt')
    if os.path.exists(trans_fit_path):
        print("[INFO] Loading existing trans fit...")
        model[cg].load_trans_fit()
    else:
        print("[INFO] Running trans fit (this may take a while)...")
        model[cg].fit_trans(sum_factor_col="sum_factor_new", tolerance=0)
        model[cg].save_trans_fit()
[INFO] Loading existing trans fit...
[LOAD] posterior_samples_trans (modality: gene, 13791 features) ← /cfs/klemming/projects/snic/lappalainen_lab1/users/Leah/data/unpublished/Josie_pilot/processed/fromJosie_Oct2025/SCE_from_Josie/10X_SR/bayesDREAM/output/20251030_GFI1B/posterior_samples_trans_gene.pt
[LOAD] gene.posterior_samples_trans (distribution: negbinom, 13791 features) ← /cfs/klemming/projects/snic/lappalainen_lab1/users/Leah/data/unpublished/Josie_pilot/processed/fromJosie_Oct2025/SCE_from_Josie/10X_SR/bayesDREAM/output/20251030_GFI1B/posterior_samples_trans_gene.pt
[LOAD] Trans fit loaded from /cfs/klemming/projects/snic/lappalainen_lab1/users/Leah/data/unpublished/Josie_pilot/processed/fromJosie_Oct2025/SCE_from_Josie/10X_SR/bayesDREAM/output/20251030_GFI1B
[LOAD] Modalities loaded: ['cis', 'gene']
[INFO] Loading existing trans fit...
[LOAD] posterior_samples_trans (modality: gene, 13791 features) ← /cfs/klemming/projects/snic/lappalainen_lab1/users/Leah/data/unpublished/Josie_pilot/processed/fromJosie_Oct2025/SCE_from_Josie/10X_SR/bayesDREAM/output/20251030_GEMIN5/posterior_samples_trans_gene.pt
[LOAD] gene.posterior_samples_trans (distribution: negbinom, 13791 features) ← /cfs/klemming/projects/snic/lappalainen_lab1/users/Leah/data/unpublished/Josie_pilot/processed/fromJosie_Oct2025/SCE_from_Josie/10X_SR/bayesDREAM/output/20251030_GEMIN5/posterior_samples_trans_gene.pt
[LOAD] Trans fit loaded from /cfs/klemming/projects/snic/lappalainen_lab1/users/Leah/data/unpublished/Josie_pilot/processed/fromJosie_Oct2025/SCE_from_Josie/10X_SR/bayesDREAM/output/20251030_GEMIN5
[LOAD] Modalities loaded: ['cis', 'gene']
[INFO] Loading existing trans fit...
[LOAD] posterior_samples_trans (modality: gene, 13791 features) ← /cfs/klemming/projects/snic/lappalainen_lab1/users/Leah/data/unpublished/Josie_pilot/processed/fromJosie_Oct2025/SCE_from_Josie/10X_SR/bayesDREAM/output/20251030_DDX6/posterior_samples_trans_gene.pt
[LOAD] gene.posterior_samples_trans (distribution: negbinom, 13791 features) ← /cfs/klemming/projects/snic/lappalainen_lab1/users/Leah/data/unpublished/Josie_pilot/processed/fromJosie_Oct2025/SCE_from_Josie/10X_SR/bayesDREAM/output/20251030_DDX6/posterior_samples_trans_gene.pt
[LOAD] Trans fit loaded from /cfs/klemming/projects/snic/lappalainen_lab1/users/Leah/data/unpublished/Josie_pilot/processed/fromJosie_Oct2025/SCE_from_Josie/10X_SR/bayesDREAM/output/20251030_DDX6
[LOAD] Modalities loaded: ['cis', 'gene']
time: 2.5 s (started: 2025-11-03 21:15:31 +01:00)

Plot results¶

Mean v CI plots¶

In [22]:
for cg in ['GFI1B', 'GEMIN5', 'DDX6']:
    samples_alpha = model[cg].posterior_samples_trans['K_a'][:, 0, :].cpu().numpy()
    samples_n     = model[cg].posterior_samples_trans['n_a'][:, 0, :].cpu().numpy()

    # Mean values
    alpha_mean = samples_alpha.mean(axis=0)
    n_mean     = samples_n.mean(axis=0)

    # 95% credible intervals
    alpha_lo, alpha_hi = np.percentile(samples_alpha, [2.5, 97.5], axis=0)
    n_lo, n_hi         = np.percentile(samples_n, [2.5, 97.5], axis=0)

    # 95% CI widths
    alpha_ci_width = alpha_hi - alpha_lo
    n_ci_width     = n_hi - n_lo

    # Dependency mask based on n's CI excluding 0
    dependent_mask = (n_lo > 0) | (n_hi < 0)
    dependent_pct = dependent_mask.mean() * 100

    # === Plot 1: K_a ===
    plt.figure()
    plt.scatter(alpha_mean[~dependent_mask], alpha_ci_width[~dependent_mask],
                s=5, alpha=0.4, color='black', label='not dependent')
    plt.scatter(alpha_mean[dependent_mask], alpha_ci_width[dependent_mask],
                s=5, alpha=0.6, color='blue', label='dependent')
    plt.xlabel(r'Mean $K_a$')
    plt.ylabel(r'95% CI width of $K_a$')
    plt.title(f"{cg} — $K_a$ uncertainty vs mean ({dependent_pct:.1f}% dependent)")
    plt.axhline(0, color='black', linestyle=':', linewidth=1)
    plt.axvline(0, color='black', linestyle=':', linewidth=1)
    plt.legend(frameon=False, loc='center left', bbox_to_anchor=(1.02, 0.5))
    plt.tight_layout()
    plt.show()

    # === Plot 2: n_a ===
    plt.figure()
    plt.scatter(n_mean[~dependent_mask], n_ci_width[~dependent_mask],
                s=5, alpha=0.4, color='black', label='not dependent')
    plt.scatter(n_mean[dependent_mask], n_ci_width[dependent_mask],
                s=5, alpha=0.6, color='blue', label='dependent')
    plt.xlabel(r'Mean $n$ (Hill coefficient)')
    plt.ylabel(r'95% CI width of $n$')
    plt.title(f"{cg} — $n_a$ uncertainty vs mean ({dependent_pct:.1f}% dependent)")
    plt.axhline(0, color='black', linestyle=':', linewidth=1)
    plt.axvline(0, color='black', linestyle=':', linewidth=1)
    plt.legend(frameon=False, loc='center left', bbox_to_anchor=(1.02, 0.5))
    plt.tight_layout()
    plt.show()
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
time: 2.72 s (started: 2025-11-03 20:43:05 +01:00)
In [23]:
def hill_xinf_samples(K_samps, n_samps, tol=0.2, x_max=None):
    S, T = n_samps.shape
    xinf = np.full((S, T), np.nan, dtype=float)
    m = np.abs(n_samps)
    base = (m - 1.0) / (m + 1.0)               # in (0,1) when |n|>1
    valid = m > (1.0 + tol)

    with np.errstate(divide='ignore', invalid='ignore'):
        log_xinf = np.log(K_samps) + (1.0 / n_samps) * np.log(base)
    xinf[valid] = np.exp(log_xinf[valid])

    if x_max is not None:
        xinf[xinf > x_max] = np.nan
    return xinf

for cg in ['GFI1B', 'GEMIN5', 'DDX6']:
    K_samps = model[cg].posterior_samples_trans['K_a'][:, 0, :].detach().cpu().numpy()
    n_samps = model[cg].posterior_samples_trans['n_a'][:, 0, :].detach().cpu().numpy()

    n_mean = n_samps.mean(axis=0)
    n_lo, n_hi = np.percentile(n_samps, [2.5, 97.5], axis=0)
    n_ci_width = n_hi - n_lo
    dependent_mask = (n_lo > 0) | (n_hi < 0)
    dependent_pct = dependent_mask.mean() * 100

    # --- Inflection plot, color by NaN fraction (dark = few NaNs) ---
    xinf_samps = hill_xinf_samples(K_samps, n_samps, tol=0.2, x_max=None)
    xinf_mean = np.nanmean(xinf_samps, axis=0)
    xinf_lo, xinf_hi = np.nanpercentile(xinf_samps, [2.5, 97.5], axis=0)
    xinf_ci_width = xinf_hi - xinf_lo
    frac_nan = np.mean(np.isnan(xinf_samps), axis=0)

    mask = dependent_mask & np.isfinite(xinf_mean) & np.isfinite(xinf_ci_width)

    plt.figure()
    sc = plt.scatter(
        xinf_mean[mask], xinf_ci_width[mask],
        c=frac_nan[mask], cmap='Blues_r', vmin=0, vmax=1,
        s=8, alpha=0.9
    )
    plt.xlabel(r'Mean inflection $x_{\mathrm{inf}}$')
    plt.ylabel(r'95% CI width of $x_{\mathrm{inf}}$')
    plt.title(f"{cg} — inflection (dependent only; {dependent_pct:.1f}% dependent)")
    cbar = plt.colorbar(sc, pad=0.02)
    cbar.set_label('fraction NaN in $x_{\\mathrm{inf}}$ samples (NaN are where abs(n)<1)')
    plt.axhline(0, color='black', linestyle=':', linewidth=1)
    plt.tight_layout()
    plt.show()

    # --- n plot, gene-level points with alpha blending ---
    plt.figure()
    # not dependent (light grey)
    plt.scatter(
        n_mean[~dependent_mask], n_ci_width[~dependent_mask],
        s=5, alpha=0.3, color='grey', label='not dependent'
    )
    # dependent (blue, darker = overlap)
    plt.scatter(
        n_mean[dependent_mask], n_ci_width[dependent_mask],
        s=5, alpha=0.2, color='blue', label='dependent'
    )

    plt.xlabel(r'Mean $n$ (Hill coefficient)')
    plt.ylabel(r'95% CI width of $n$')
    plt.title(f"{cg} — $n$ uncertainty vs mean ({dependent_pct:.1f}% dependent)")
    plt.legend(frameon=False, loc='center left', bbox_to_anchor=(1.02, 0.5))
    plt.axhline(0, color='black', linestyle=':', linewidth=1)
    plt.tight_layout()
    plt.show()
/tmp/ipykernel_2603555/1393261059.py:9: RuntimeWarning: overflow encountered in divide
  log_xinf = np.log(K_samps) + (1.0 / n_samps) * np.log(base)
/tmp/ipykernel_2603555/1393261059.py:28: RuntimeWarning: Mean of empty slice
  xinf_mean = np.nanmean(xinf_samps, axis=0)
/cfs/klemming/projects/snic/lappalainen_lab1/users/Leah/software/miniconda3/envs/pyroenv/lib/python3.12/site-packages/numpy/lib/_nanfunctions_impl.py:1650: RuntimeWarning: All-NaN slice encountered
  return fnb._ureduce(a,
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
time: 5.16 s (started: 2025-11-03 20:43:07 +01:00)

Posterior density lines plots¶

Gene-level parameters¶

In [24]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
from scipy.stats import gaussian_kde
from matplotlib import cm

def plot_posterior_density_lines(
    samples,                        # [S, T]
    title="Posterior density lines",
    sort_by="median",
    subset_mask=None,
    cmap="viridis",
    alpha_overall=0.5,
    density_gamma=0.7,
    norm_global=True,
    y_quantiles=(0.5, 99.5),
    grid_points=350,
    linewidth=0.8,
    add_median_lines=True,
    y_label=r"$\theta$",
    ax=None,
    show=True,
    y_range=None,                   # <-- NEW
):
    """Plot per-feature posterior densities as vertical colour lines."""
    samples = np.asarray(samples)

    if samples.ndim == 1:
        samples = samples[:, None]
    elif samples.ndim > 2:
        samples = samples.reshape(samples.shape[0], -1)

    S, T = samples.shape

    if subset_mask is not None:
        subset_mask = np.asarray(subset_mask, dtype=bool)
        samples = samples[:, subset_mask]
        S, T = samples.shape

    if sort_by == "median":
        order = np.argsort(np.nanmedian(samples, axis=0))
    elif sort_by == "mean":
        order = np.argsort(np.nanmean(samples, axis=0))
    else:
        order = np.arange(T)
    samples_sorted = samples[:, order]

    # --- y-range: either from samples, or overridden explicitly ---
    if y_range is None:
        y_min, y_max = np.nanpercentile(samples_sorted, y_quantiles)
    else:
        y_min, y_max = y_range

    y_grid = np.linspace(y_min, y_max, grid_points)

    # KDE per feature
    dens_list = []
    for t in range(T):
        vals = samples_sorted[:, t]
        vals = vals[~np.isnan(vals)]
        if vals.size < 2:
            dens = np.zeros_like(y_grid)
        else:
            kde = gaussian_kde(vals)
            dens = kde(y_grid)
        dens_list.append(dens)
    dens_mat = np.stack(dens_list, axis=0)

    if norm_global:
        m = dens_mat.max() + 1e-12
        dens_norm = (dens_mat / m) ** density_gamma
    else:
        m = dens_mat.max(axis=1, keepdims=True) + 1e-12
        dens_norm = (dens_mat / m) ** density_gamma

    L = len(y_grid)
    segs_all, cols_all = [], []
    cmap_obj = cm.get_cmap(cmap)

    for x_pos in range(T):
        x = np.full(L, x_pos, dtype=float)
        pts = np.column_stack([x, y_grid])
        segs = np.stack([pts[:-1], pts[1:]], axis=1)
        c = cmap_obj(dens_norm[x_pos, :-1])
        c[:, 3] = alpha_overall
        segs_all.append(segs)
        cols_all.append(c)

    segs_all = np.concatenate(segs_all, axis=0)
    cols_all = np.concatenate(cols_all, axis=0)

    if ax is None:
        fig, ax = plt.subplots(figsize=(12, 4))
    else:
        fig = ax.figure

    lc = LineCollection(segs_all, colors=cols_all, linewidths=linewidth)
    ax.add_collection(lc)

    if add_median_lines:
        med = np.nanmedian(samples_sorted, axis=0)
        for i, m_val in enumerate(med):
            arr = np.asarray(m_val)
            if arr.size == 0:
                continue
            m_val_scalar = float(arr.ravel()[0])
            if not np.isfinite(m_val_scalar):
                continue
            ax.hlines(
                m_val_scalar,
                i - 0.4,
                i + 0.4,
                color="white",
                linewidth=0.5,
                alpha=0.9,
            )

    ax.set_xlim(-0.5, T - 0.5)
    ax.set_ylim(y_min, y_max)
    ax.axhline(0, color='black', linestyle=':', linewidth=1)
    ax.set_xlabel("Genes (equal-width bins)")
    ax.set_ylabel(y_label)
    if title:
        ax.set_title(title)
    ax.set_xticks([])

    if show:
        fig.tight_layout()
        plt.show()

    return ax


def dependency_mask_from_n(n_samps, ci=95.0):
    """95% CI of n excludes 0."""
    lo_q = (100 - ci) / 2.0
    hi_q = 100 - lo_q
    lo, hi = np.percentile(n_samps, [lo_q, hi_q], axis=0)
    return (lo > 0) | (hi < 0)

def abs_n_gt_tol_mask(n_samps, tol=1.0):
    """|median(n)| > 1+tol_n_for_xinf  (tol is the extra beyond 1)."""
    med_abs = np.abs(np.median(n_samps, axis=0))
    return med_abs > (1.0 + tol)

def hill_xinf_samples(K_samps, n_samps, tol_n=0.0):
    """
    Compute per-sample x_inf for Hill curves.
    K_samps, n_samps: [S, T]
    Returns [S, T] with NaN where |n| <= 1+tol_n.
    """
    m = np.abs(n_samps)
    base = (m - 1.0) / (m + 1.0)

    with np.errstate(divide='ignore', invalid='ignore'):
        log_xinf = np.log(K_samps) + (1.0 / n_samps) * np.log(base)
        xinf = np.exp(log_xinf)

    xinf[m <= (1.0 + tol_n)] = np.nan
    return xinf

def hill_y(x, A, alpha, K, n, eps=1e-8):
    """
    Vectorized Hill function for arrays (broadcasting OK).
    A, alpha, K, n can be [S, T]; x can be scalar or array.
    """
    x = np.asarray(x, dtype=float)
    # ensure broadcasting: add trailing axes if needed
    while A.ndim > x.ndim:
        x = np.expand_dims(x, axis=0)
    x_n = np.power(x, n)
    K_n = np.power(K, n)
    y = A + alpha * x_n / (K_n + x_n + eps)
    return y

import numpy as np

def compute_log2fc_metrics(A_samps, alpha_samps, Vmax_samps, K_samps, n_samps,
                           x_true_samps, eps=1e-6, n_zero_tol=1e-6):
    """
    Directional log2 fold-change metrics for:

        y(x) = A + alpha * Vmax * x^n / (K^n + x^n)

    log2fc_full:  between y(x→∞) and y(x→0), with sign determined by n.
    log2fc_obs:   between y(x_max_obs) and y(x_min_obs).
    """
    # ensure arrays
    A_samps     = np.asarray(A_samps)
    alpha_samps = np.asarray(alpha_samps)
    Vmax_samps  = np.asarray(Vmax_samps)
    K_samps     = np.asarray(K_samps)
    n_samps     = np.asarray(n_samps)

    # --- observed x range from mean x_true across samples per cell ---
    X = np.asarray(x_true_samps)              # [S, N_cells]
    x_means_per_cell = X.mean(axis=0)         # [N_cells]
    x_min_obs = float(x_means_per_cell.min())
    x_max_obs = float(x_means_per_cell.max())

    A     = A_samps
    alpha = alpha_samps
    Vmax  = Vmax_samps

    # ---------------- full-range FC with n sign ----------------
    # sign of n determines whether y increases or decreases with x
    n_sign = np.sign(n_samps)
    # treat near-zero n as flat (no direction)
    flat_mask = np.abs(n_samps) < n_zero_tol
    n_sign[flat_mask] = 0.0

    # asymptotes:
    # if n > 0:  y(0) = A,             y(∞) = A + alpha*Vmax
    # if n < 0:  y(0) = A + alpha*Vmax, y(∞) = A
    y0_full   = np.where(n_sign >= 0, A, A + alpha * Vmax)
    yinf_full = np.where(n_sign >= 0, A + alpha * Vmax, A)

    log2fc_full = np.zeros_like(A, dtype=float)
    changing_mask = n_sign != 0.0
    log2fc_full[changing_mask] = np.log2(
        (yinf_full[changing_mask] + eps) /
        (y0_full[changing_mask]   + eps)
    )
    # flat_mask stays at 0

    # ---------------- helper: y(x) under this parametrisation ----------------
    def y_hill(x_scalar, A, alpha, Vmax, K, n, eps_inner=1e-8):
        """
        y(x) = A + alpha * Vmax * x^n / (K^n + x^n), evaluated at scalar x.
        """
        x = float(x_scalar)
        x_safe = x + eps_inner
        K_safe = K + eps_inner

        with np.errstate(divide='ignore', invalid='ignore'):
            x_log = np.log(x_safe)
            K_log = np.log(K_safe)
            x_n = np.exp(n * x_log)
            K_n = np.exp(n * K_log)

        frac = x_n / (K_n + x_n + eps_inner)
        h = Vmax * frac
        return A + alpha * h

    # ---------------- observed-range FC: x_min_obs -> x_max_obs ----------------
    Y_min_obs = y_hill(x_min_obs, A, alpha, Vmax, K_samps, n_samps, eps_inner=eps)
    Y_max_obs = y_hill(x_max_obs, A, alpha, Vmax, K_samps, n_samps, eps_inner=eps)

    log2fc_obs = np.log2((Y_max_obs + eps) / (Y_min_obs + eps))

    return log2fc_full, log2fc_obs, x_min_obs, x_max_obs
time: 4.68 ms (started: 2025-11-03 20:43:13 +01:00)
In [25]:
for cg in ['GFI1B', 'GEMIN5', 'DDX6']:
    n_samps = model[cg].posterior_samples_trans['n_a'][:, 0, :].detach().cpu().numpy()
    dep_mask = dependency_mask_from_n(n_samps)

    plot_posterior_density_lines(
        n_samps,
        title=f"{cg} — posterior of $n$",
        subset_mask=dep_mask,
        cmap="viridis",
        alpha_overall=0.45,
        density_gamma=0.7,
        add_median_lines=True,
        y_label=r"$n$ (Hill coefficient)"
    )
/tmp/ipykernel_2603555/114306776.py:78: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed in 3.11. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap()`` or ``pyplot.get_cmap()`` instead.
  cmap_obj = cm.get_cmap(cmap)
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
time: 2min 7s (started: 2025-11-03 20:43:13 +01:00)
In [26]:
tol_n_for_xinf = 0.2  # extra margin beyond 1

def log2_pos(a):
    a = np.asarray(a)
    out = np.full_like(a, np.nan, dtype=float)
    mask = a > 0
    out[mask] = np.log2(a[mask])
    return out

for cg in ['GFI1B', 'GEMIN5', 'DDX6']:
    K_samps = model[cg].posterior_samples_trans['K_a'][:, 0, :].detach().cpu().numpy()
    n_samps = model[cg].posterior_samples_trans['n_a'][:, 0, :].detach().cpu().numpy()

    dep_mask   = dependency_mask_from_n(n_samps)
    abs_gt_tol = abs_n_gt_tol_mask(n_samps, tol=tol_n_for_xinf)
    mask = dep_mask & abs_gt_tol

    xinf_samps = hill_xinf_samples(K_samps, n_samps, tol_n=tol_n_for_xinf)
    log2_xinf_samps = log2_pos(xinf_samps)

    # also drop genes where xinf is NaN for all samples
    mask &= ~np.all(np.isnan(xinf_samps), axis=0)

    plot_posterior_density_lines(
        log2_xinf_samps,
        title=f"{cg} — posterior of $x_{{\\mathrm{{inf}}}}$ (dependent, |n|>1+{tol_n_for_xinf})",
        subset_mask=mask,
        cmap="viridis",
        alpha_overall=0.45,
        density_gamma=0.7,
        add_median_lines=True,
        y_label=r'$\log_2 x_{\mathrm{inf}}$'
    )
/tmp/ipykernel_2603555/114306776.py:156: RuntimeWarning: overflow encountered in divide
  log_xinf = np.log(K_samps) + (1.0 / n_samps) * np.log(base)
/tmp/ipykernel_2603555/114306776.py:78: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed in 3.11. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap()`` or ``pyplot.get_cmap()`` instead.
  cmap_obj = cm.get_cmap(cmap)
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
time: 1min 54s (started: 2025-11-03 20:45:21 +01:00)
In [27]:
for cg in ['GFI1B', 'GEMIN5', 'DDX6']:
    K_samps = model[cg].posterior_samples_trans['K_a'][:, 0, :].detach().cpu().numpy()
    n_samps = model[cg].posterior_samples_trans['n_a'][:, 0, :].detach().cpu().numpy()
    dep_mask = dependency_mask_from_n(n_samps)
    log2_K_samps = log2_pos(K_samps)

    plot_posterior_density_lines(
        log2_K_samps,
        title=f"{cg} — posterior of $K_a$ (dependent only)",
        subset_mask=dep_mask,
        cmap="viridis",
        alpha_overall=0.45,
        density_gamma=0.7,
        add_median_lines=True,
        y_label=r"$\log_2 K_a$"
    )
/tmp/ipykernel_2603555/114306776.py:78: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed in 3.11. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap()`` or ``pyplot.get_cmap()`` instead.
  cmap_obj = cm.get_cmap(cmap)
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
time: 2min (started: 2025-11-03 20:47:15 +01:00)
In [28]:
for cg in ['GFI1B', 'GEMIN5', 'DDX6']:
    # --- Extract posterior samples ---
    A_samps     = model[cg].posterior_samples_trans['A'][:, 0, :].cpu().numpy()
    alpha_samps = model[cg].posterior_samples_trans['alpha'][:, 0, :].cpu().numpy()
    Vmax_samps  = model[cg].posterior_samples_trans['Vmax_a'][:, 0, :].cpu().numpy()
    K_samps     = model[cg].posterior_samples_trans['K_a'][:, 0, :].cpu().numpy()
    n_samps     = model[cg].posterior_samples_trans['n_a'][:, 0, :].cpu().numpy()
    x_true_samps = model[cg].x_true.detach().cpu().numpy()
    
    # --- Compute log2FC metrics ---
    log2fc_full, log2fc_obs, x_min_obs, x_max_obs = compute_log2fc_metrics(
        A_samps, alpha_samps, Vmax_samps, K_samps, n_samps, x_true_samps
    )

    # --- Gene-level means ---
    full_mean = np.nanmean(log2fc_full, axis=0)
    obs_mean  = np.nanmean(log2fc_obs, axis=0)

    # --- Dependency mask (95% CI of n excludes 0) ---
    lo, hi = np.percentile(n_samps, [2.5, 97.5], axis=0)
    dep_mask = (lo > 0) | (hi < 0)

    # ============================================================
    # 1️⃣  Correlation between full-range and observed log2FC
    # ============================================================
    plt.figure(figsize=(5.5, 5))
    plt.scatter(full_mean[~dep_mask], obs_mean[~dep_mask],
                s=10, alpha=0.25, color='grey', label='not dependent')
    plt.scatter(full_mean[dep_mask],  obs_mean[dep_mask],
                s=10, alpha=0.6, color='blue', label='dependent')

    # 1:1 line
    lim_min = min(full_mean.min(), obs_mean.min())
    lim_max = max(full_mean.max(), obs_mean.max())
    plt.plot([lim_min, lim_max], [lim_min, lim_max],
             color='black', linestyle=':', linewidth=1)

    plt.xlabel(r'log$_2$ Fold-Change (full dynamic range: $A \rightarrow A + \alpha V_{\max}$)')
    plt.ylabel(r'log$_2$ Fold-Change (within observed $x_{\min} \rightarrow x_{\max}$)')
    plt.title(f"{cg}: Relationship between full and observed dynamic range")
    plt.legend(frameon=False, loc='best')
    plt.grid(True, linewidth=0.5, alpha=0.4)
    plt.tight_layout()
    plt.show()

    # ============================================================
    # 2️⃣  Distribution of full-range log₂FC
    # ============================================================
    plot_posterior_density_lines(
        log2fc_full,
        title=fr"{cg} — Posterior distribution of log$_2$ Fold-Change (full dynamic range)",
        subset_mask=dep_mask,
        cmap="viridis",
        alpha_overall=0.45,
        density_gamma=0.7,
        add_median_lines=True,
        y_label=r"log$_2$ Fold-Change (full dynamic range: $A \rightarrow A + \alpha V_{\max}$)",
    )

    # ============================================================
    # 3️⃣  Distribution of observed-range log₂FC
    # ============================================================
    plot_posterior_density_lines(
        log2fc_obs,
        title=fr"{cg} — Posterior distribution of log$_2$ Fold-Change (within observed $x$-range)",
        subset_mask=dep_mask,
        cmap="viridis",
        alpha_overall=0.45,
        density_gamma=0.7,
        add_median_lines=True,
        y_label=r"log$_2$ Fold-Change (observed range: $x_{\min} \rightarrow x_{\max}$)",
    )
No description has been provided for this image
/tmp/ipykernel_2603555/114306776.py:78: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed in 3.11. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap()`` or ``pyplot.get_cmap()`` instead.
  cmap_obj = cm.get_cmap(cmap)
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
time: 4min 51s (started: 2025-11-03 20:49:16 +01:00)

Cell-level parameters¶

In [29]:
from matplotlib.lines import Line2D
from matplotlib.transforms import Bbox

def plot_xtrue_density_by_guide(
    model,
    cg,
    log2=False,
    cmap="viridis",
    alpha_overall=0.5,
    density_gamma=0.7,
    norm_global=True,
    y_quantiles=(0.5, 99.5),
    grid_points=350,
    linewidth=0.8,
    group_by_guide=True,
):
    """
    One vertical density line per *cell* for x_true, matching the style of
    plot_posterior_density_lines, with guides indicated by:

      - coloured horizontal median ticks per cell
      - a coloured bar between title and axes showing guide per cell
      - a legend mapping colour -> guide

    group_by_guide:
        True  -> cells grouped by guide, then median within guide
        False -> cells ordered only by median, but colours still show guide ID.
    """
    df = model[cg].meta.copy()
    X = to_np(model[cg].x_true)  # [S, N_cells]

    # log2 transform without dropping guides
    if log2:
        eps = 1e-6
        X = np.log2(np.maximum(X, eps))

    samples = np.asarray(X)       # [S, N]
    S, N = samples.shape
    guides = df['guide'].astype(str).to_numpy()   # length N

    # ---------- choose ordering ----------
    med_per_cell = np.nanmedian(samples, axis=0)

    if group_by_guide:
        # order cells: by guide, then median within guide
        def guide_sort_key(g):
            parts = g.split('_')
            root = parts[0]
            idx  = int(parts[1]) if len(parts) > 1 and parts[1].isdigit() else 0
            return (root, idx)

        unique_guides = sorted(np.unique(guides), key=guide_sort_key)
        guide_block_rank = {g: i for i, g in enumerate(unique_guides)}
        guide_ranks = np.array([guide_block_rank[g] for g in guides])

        order = np.lexsort((med_per_cell, guide_ranks))  # (N,)
    else:
        unique_guides = sorted(np.unique(guides))
        order = np.argsort(med_per_cell)

    samples_sorted = samples[:, order]
    guides_sorted  = guides[order]

    # ---------- draw density background using generic function ----------
    ylabel = "x_true" + (" (log₂)" if log2 else "")

    # no axes title (we'll use a figure-level title instead)
    ax = plot_posterior_density_lines(
        samples_sorted,
        title="",                    # <-- important: no axis title
        sort_by=None,
        subset_mask=None,
        cmap=cmap,
        alpha_overall=alpha_overall,
        density_gamma=density_gamma,
        norm_global=norm_global,
        y_quantiles=y_quantiles,
        grid_points=grid_points,
        linewidth=linewidth,
        add_median_lines=False,
        y_label=ylabel,
        ax=None,
        show=False,                  # we'll manage layout
    )

    fig = ax.figure
    # First, lay out the main axes nicely
    fig.tight_layout(rect=[0, 0, 0.98, 0.93])  # leave some top margin

    # Now get the *final* axis position after tight_layout
    from matplotlib.transforms import Bbox
    pos = ax.get_position()

    # ---------- coloured median ticks per cell (guide-coded) ----------
    med = np.nanmedian(samples_sorted, axis=0)
    for i, m_val in enumerate(med):
        arr = np.asarray(m_val)
        if arr.size == 0:
            continue
        m_val_scalar = float(arr.ravel()[0])
        if not np.isfinite(m_val_scalar):
            continue

        g = str(guides_sorted[i])
        tick_color = guide_colors.get(g, (1.0, 1.0, 1.0, 1.0))

        ax.hlines(
            m_val_scalar,
            i - 0.4,
            i + 0.4,
            color=tick_color,
            linewidth=0.8,
            alpha=0.9,
            linestyle="solid",
            zorder=3,
        )

    ax.set_xlim(-0.5, N - 0.5)
    ax.set_xlabel("Cells (grouped by guide)" if group_by_guide else "Cells (ordered by median)")
    ax.axhline(0, color='black', linestyle=':', linewidth=1)

    # ---------- coloured bar between title and axes ----------
    bar_height_frac = 0.06   # ~6% of axis height
    bar_gap_frac    = 0.02   # small gap above axes

    bar_bottom = pos.y1 + (pos.height * bar_gap_frac)
    bar_top    = bar_bottom + (pos.height * bar_height_frac)
    bar_pos    = Bbox.from_extents(pos.x0, bar_bottom, pos.x1, bar_top)

    bar_ax = fig.add_axes(bar_pos)
    bar_ax.set_xlim(-0.5, N - 0.5)
    bar_ax.set_ylim(0, 1)
    bar_ax.axis("off")

    # contiguous runs of the same guide along x
    start = 0
    current = guides_sorted[0]
    segments = []
    for i in range(1, N):
        if guides_sorted[i] != current:
            segments.append((start, i - 1, current))
            start = i
            current = guides_sorted[i]
    segments.append((start, N - 1, current))

    for s, e, g in segments:
        color = guide_colors.get(g, "black")
        bar_ax.axvspan(s - 0.5, e + 0.5, color=color)

    # ---------- legend ----------
    handles = []
    labels  = []
    for g in unique_guides:
        color = guide_colors.get(g, 'black')
        handles.append(Line2D([0], [0], color=color, lw=3))
        labels.append(g)
    ax.legend(handles, labels, title="guide", frameon=False,
              bbox_to_anchor=(1.02, 0.5), loc="center left")

    # figure-level title at the very top
    fig.suptitle(f"{cg}: posterior of x_true per cell", y=0.99)

    # no more tight_layout calls here
    plt.show()
time: 3.02 ms (started: 2025-11-03 20:54:07 +01:00)
In [30]:
for cg in ['GFI1B', 'GEMIN5', 'DDX6']:
    plot_xtrue_density_by_guide(model, cg, log2=True, group_by_guide=True)
    plot_xtrue_density_by_guide(model, cg, log2=True, group_by_guide=False)
/tmp/ipykernel_2603555/114306776.py:78: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed in 3.11. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap()`` or ``pyplot.get_cmap()`` instead.
  cmap_obj = cm.get_cmap(cmap)
No description has been provided for this image
/tmp/ipykernel_2603555/114306776.py:78: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed in 3.11. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap()`` or ``pyplot.get_cmap()`` instead.
  cmap_obj = cm.get_cmap(cmap)
No description has been provided for this image
/tmp/ipykernel_2603555/114306776.py:78: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed in 3.11. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap()`` or ``pyplot.get_cmap()`` instead.
  cmap_obj = cm.get_cmap(cmap)
No description has been provided for this image
/tmp/ipykernel_2603555/114306776.py:78: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed in 3.11. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap()`` or ``pyplot.get_cmap()`` instead.
  cmap_obj = cm.get_cmap(cmap)
No description has been provided for this image
/tmp/ipykernel_2603555/114306776.py:78: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed in 3.11. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap()`` or ``pyplot.get_cmap()`` instead.
  cmap_obj = cm.get_cmap(cmap)
No description has been provided for this image
/tmp/ipykernel_2603555/114306776.py:78: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed in 3.11. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap()`` or ``pyplot.get_cmap()`` instead.
  cmap_obj = cm.get_cmap(cmap)
No description has been provided for this image
time: 2min 11s (started: 2025-11-03 20:54:07 +01:00)
In [31]:
from scipy.stats import gaussian_kde

tol_n_for_xinf = 0.2  # extra margin beyond 1

def log2_pos(a):
    a = np.asarray(a)
    out = np.full_like(a, np.nan, dtype=float)
    mask = a > 0
    out[mask] = np.log2(a[mask])
    return out

for cg in ['GFI1B', 'GEMIN5', 'DDX6']:
    K_samps = model[cg].posterior_samples_trans['K_a'][:, 0, :].detach().cpu().numpy()
    n_samps = model[cg].posterior_samples_trans['n_a'][:, 0, :].detach().cpu().numpy()

    dep_mask   = dependency_mask_from_n(n_samps)
    abs_gt_tol = abs_n_gt_tol_mask(n_samps, tol=tol_n_for_xinf)
    mask = dep_mask & abs_gt_tol

    xinf_samps      = hill_xinf_samples(K_samps, n_samps, tol_n=tol_n_for_xinf)
    log2_xinf_samps = log2_pos(xinf_samps)

    # drop genes where x_inf is NaN for all samples
    mask &= ~np.all(np.isnan(xinf_samps), axis=0)

    # ---- compute log2 mean x_true per cell & its y-range ----
    df_meta = model[cg].meta.copy()
    x_true_samps = model[cg].x_true.detach().cpu().numpy()   # [S, N_cells]
    xtrue_mean_per_cell = x_true_samps.mean(axis=0)          # [N_cells]
    log2_xtrue_means = log2_pos(xtrue_mean_per_cell)
    vals_all = log2_xtrue_means[~np.isnan(log2_xtrue_means)]

    # y-range driven by x_true distribution (e.g. central 99% of cells)
    y_min = np.percentile(vals_all, 0.5)
    y_max = np.percentile(vals_all, 99.5)
    y_range = (y_min, y_max)

    # ---- global 95% CI of log2 x_inf (for reference) ----
    inf_vals_all = log2_xinf_samps[:, mask]
    inf_vals_all = inf_vals_all[~np.isnan(inf_vals_all)]
    if inf_vals_all.size > 0:
        ci_lo, ci_hi = np.percentile(inf_vals_all, [2.5, 97.5])
    else:
        ci_lo, ci_hi = y_min, y_max

    # ---------------- figure with 2 panels: main + side density ----------------
    fig, (ax_main, ax_side) = plt.subplots(
        1, 2,
        figsize=(8, 5),
        gridspec_kw={"width_ratios": [4, 1], "wspace": 0.05},
        sharey=True,
    )

    # ----- main posterior density of log2 x_inf -----
    ax_main = plot_posterior_density_lines(
        log2_xinf_samps,
        title=rf"{cg} — posterior of $\log_2 x_{{\mathrm{{inf}}}}$ (dependent, |n|>1+{tol_n_for_xinf})",
        subset_mask=mask,
        cmap="viridis",
        alpha_overall=0.45,
        density_gamma=0.7,
        add_median_lines=True,
        y_label=r'$\log_2 x_{\mathrm{inf}}$',
        ax=ax_main,
        show=False,
        y_range=y_range,   # <-- use x_true-driven scale
    )
    ax_main.set_xlabel("Genes (equal-width bins)")
    ax_main.set_ylim(y_min, y_max)

    # left y ticks
    ax_main.set_yticks(np.linspace(y_min, y_max, 5))
    ax_main.yaxis.set_ticks_position('left')
    ax_main.tick_params(axis='y', which='both', length=4)

    # Optionally indicate global 95% CI region for x_inf
    ax_main.axhline(ci_lo, color='white', linestyle=':', linewidth=0.7, alpha=0.7)
    ax_main.axhline(ci_hi, color='white', linestyle=':', linewidth=0.7, alpha=0.7)

    # ----- sideways densities of log2 mean x_true by target -----
    ax_side.set_xlabel(r'density of $\log_2 x_{\mathrm{true}}$', fontsize=9)
    ax_side.xaxis.set_label_position('top')

    targets = df_meta['target'].astype(str).to_numpy()
    uniq_targets = sorted(np.unique(targets))

    y_grid = np.linspace(y_min, y_max, 400)

    for t in uniq_targets:
        mask_t = targets == t
        vals_t = log2_xtrue_means[mask_t]
        vals_t = vals_t[~np.isnan(vals_t)]
        if vals_t.size == 0:
            continue

        color = target_colors.get(t, 'grey')

        if vals_t.size < 2:
            # tiny bump if no variance
            y0 = vals_t[0]
            bump = np.exp(-0.5 * ((y_grid - y0) / 0.05) ** 2)
            bump /= bump.max() + 1e-12
            ax_side.fill_betweenx(y_grid, 0, bump, color=color, alpha=0.45)
            ax_side.plot(bump, y_grid, color=color, linewidth=1.0)
        else:
            kde = gaussian_kde(vals_t)
            dens_t = kde(y_grid)
            dens_t /= dens_t.max() + 1e-12
            ax_side.fill_betweenx(y_grid, 0, dens_t, color=color, alpha=0.45)
            ax_side.plot(dens_t, y_grid, color=color, linewidth=1.0)

    ax_side.set_xlim(0, 1.05)

    # mirror the y-axis on the right for sanity
    ax_side.yaxis.set_ticks_position('right')
    ax_side.yaxis.set_label_position('right')
    ax_side.set_yticks(np.linspace(y_min, y_max, 5))
    ax_side.tick_params(axis='y', which='both', length=4)
    ax_side.set_ylabel("")  # keep only left label to avoid clutter

    fig.tight_layout()
    plt.show()
/tmp/ipykernel_2603555/114306776.py:156: RuntimeWarning: overflow encountered in divide
  log_xinf = np.log(K_samps) + (1.0 / n_samps) * np.log(base)
/tmp/ipykernel_2603555/114306776.py:78: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed in 3.11. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap()`` or ``pyplot.get_cmap()`` instead.
  cmap_obj = cm.get_cmap(cmap)
/tmp/ipykernel_2603555/4235410720.py:121: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.
  fig.tight_layout()
No description has been provided for this image
/tmp/ipykernel_2603555/114306776.py:156: RuntimeWarning: overflow encountered in divide
  log_xinf = np.log(K_samps) + (1.0 / n_samps) * np.log(base)
/tmp/ipykernel_2603555/114306776.py:78: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed in 3.11. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap()`` or ``pyplot.get_cmap()`` instead.
  cmap_obj = cm.get_cmap(cmap)
/tmp/ipykernel_2603555/4235410720.py:121: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.
  fig.tight_layout()
No description has been provided for this image
/tmp/ipykernel_2603555/114306776.py:156: RuntimeWarning: overflow encountered in divide
  log_xinf = np.log(K_samps) + (1.0 / n_samps) * np.log(base)
/tmp/ipykernel_2603555/114306776.py:78: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed in 3.11. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap()`` or ``pyplot.get_cmap()`` instead.
  cmap_obj = cm.get_cmap(cmap)
/tmp/ipykernel_2603555/4235410720.py:121: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.
  fig.tight_layout()
No description has been provided for this image
time: 2min 15s (started: 2025-11-03 20:56:18 +01:00)

Mean results plots¶

In [32]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm

# assumes:
# - compute_log2fc_metrics is already defined (directional)
# - dependency_mask_from_n is already defined
# - target_colors dict exists, e.g. {'NTC': ..., 'GFI1B': ..., ...}

# -----------------------------
# x_inf helper
# -----------------------------
def hill_xinf_samples(K_samps, n_samps, tol_n=0.0):
    """
    Compute per-sample x_inf for Hill curves.

    For the positive Hill function using (K, n), the point of inflexion is:
        x_inf = K * ((|n|-1)/(|n|+1))^(1/n)
    This only makes sense when |n| > 1 + tol_n. Otherwise we return NaN.

    K_samps, n_samps: [S, T]
    Returns:
        xinf: [S, T] with NaN where |n| <= 1+tol_n.
    """
    K_samps = np.asarray(K_samps)
    n_samps = np.asarray(n_samps)

    m = np.abs(n_samps)
    base = (m - 1.0) / (m + 1.0)

    with np.errstate(divide='ignore', invalid='ignore'):
        log_xinf = np.log(K_samps) + (1.0 / n_samps) * np.log(base)
        xinf = np.exp(log_xinf)

    xinf[m <= (1.0 + tol_n)] = np.nan
    return xinf

# -----------------------------
# generic scatter helper
# -----------------------------
def scatter_nice_dep_plot(
    x_vals,
    y_vals,
    dep_mask,
    cg,
    xlabel,
    ylabel,
    title,
    target_colors,
    alpha=0.3,
    s=5,
    add_zero_guides=True,
):
    """
    Make a dependency-coloured scatter with a square plotting area
    and the legend placed fully outside (to the right).
    """
    import matplotlib.cm as cm
    import matplotlib.pyplot as plt
    import numpy as np

    x_vals = np.asarray(x_vals)
    y_vals = np.asarray(y_vals)
    dep_mask = np.asarray(dep_mask, dtype=bool)

    finite = np.isfinite(x_vals) & np.isfinite(y_vals)
    if not np.any(finite):
        print(f"[{cg}] No finite points for this plot.")
        return

    x = x_vals[finite]
    y = y_vals[finite]
    dep = dep_mask[finite]
    nondep = ~dep

    color_ntc = target_colors.get("NTC", cm.Greys(0.6))
    color_cg = target_colors.get(cg, "blue")

    # make a *square plot box*, leave right margin for legend
    fig, ax = plt.subplots(figsize=(5.5, 5.5))
    ax.set_box_aspect(1)  # ensures square axes box

    # scatter points
    ax.scatter(
        x[nondep],
        y[nondep],
        s=s,
        alpha=alpha,
        color=color_ntc,
        label="non-dependent",
    )
    ax.scatter(
        x[dep],
        y[dep],
        s=s,
        alpha=alpha,
        color=color_cg,
        label=f"{cg} dependent",
    )

    # zero guides
    if add_zero_guides:
        ax.axhline(0, color="black", linestyle=":", linewidth=1)
        ax.axvline(0, color="black", linestyle=":", linewidth=1)

    # labels and title
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)

    # legend outside, keeping the plot perfectly square
    leg = ax.legend(
        frameon=False,
        loc="center left",
        bbox_to_anchor=(1.02, 0.5),  # move outside right
        borderaxespad=0.0,
    )
    for lh in leg.legend_handles:
        try:
            lh.set_sizes([50])
        except Exception:
            pass

    # tidy layout — prevents label cutoff but doesn’t distort square axes
    fig.subplots_adjust(right=0.78)  # reserve space for legend
    ax.grid(True, linewidth=0.5, alpha=0.4)

    plt.show()


# -----------------------------
# main loop over cis genes
# -----------------------------
# -----------------------------
# main loop over cis genes
# -----------------------------
tol_n_for_xinf = 0.2  # extra margin beyond 1 for x_inf

for cg in ['GFI1B', 'GEMIN5', 'DDX6']:
    print(f"\n=== {cg}: FC/n/x_inf relationships ===")

    # posterior samples
    A_samps      = model[cg].posterior_samples_trans['A'][:, 0, :].detach().cpu().numpy()
    alpha_samps  = model[cg].posterior_samples_trans['alpha'][:, 0, :].detach().cpu().numpy()
    Vmax_samps   = model[cg].posterior_samples_trans['Vmax_a'][:, 0, :].detach().cpu().numpy()
    K_samps      = model[cg].posterior_samples_trans['K_a'][:, 0, :].detach().cpu().numpy()
    n_samps      = model[cg].posterior_samples_trans['n_a'][:, 0, :].detach().cpu().numpy()
    x_true_samps = model[cg].x_true.detach().cpu().numpy()

    # directional log2FCs from the Hill model
    log2fc_full, log2fc_obs, x_min_obs, x_max_obs = compute_log2fc_metrics(
        A_samps, alpha_samps, Vmax_samps, K_samps, n_samps, x_true_samps
    )

    # per-gene means
    n_mean             = np.mean(n_samps, axis=0)         # [T]
    log2fc_full_mean   = np.mean(log2fc_full, axis=0)     # [T]

    # dependency mask
    dep_mask = dependency_mask_from_n(n_samps)
    dep_pct  = 100.0 * np.sum(dep_mask) / len(dep_mask)

    # x_inf samples + mean, then log2-transform
    xinf_samps   = hill_xinf_samples(K_samps, n_samps, tol_n=tol_n_for_xinf)  # [S, T]
    xinf_mean    = np.nanmean(xinf_samps, axis=0)  # [T]
    log2_xinf_mean = np.log2(xinf_mean)
    xinf_finite_mask = np.isfinite(log2_xinf_mean)

    # -----------------------------
    # 1) mean n vs mean full-range log2FC
    # -----------------------------
    scatter_nice_dep_plot(
        x_vals=n_mean,
        y_vals=log2fc_full_mean,
        dep_mask=dep_mask,
        cg=cg,
        target_colors=target_colors,
        xlabel=rf"mean $n$ (Hill coefficient)",
        ylabel=rf"mean log$_2$FC ($y(x\to\infty)$ vs $y(x\to 0)$)",
        title=f"{cg}: full-range log$_2$FC vs Hill coefficient\n({dep_pct:.1f}% dependent)",
    )

    # -----------------------------
    # 2) mean n vs log2(x_inf)
    # -----------------------------
    dep_mask_xinf = dep_mask & xinf_finite_mask
    scatter_nice_dep_plot(
        x_vals=n_mean[xinf_finite_mask],
        y_vals=log2_xinf_mean[xinf_finite_mask],
        dep_mask=dep_mask_xinf[xinf_finite_mask],
        cg=cg,
        target_colors=target_colors,
        xlabel=rf"mean $n$ (Hill coefficient)",
        ylabel=rf"log$_2(x_{{\mathrm{{inf}}}})$",
        title=f"{cg}: log$_2(x_{{\mathrm{{inf}}}})$ vs Hill coefficient\n(only |n|>1+{tol_n_for_xinf})",
    )

    # -----------------------------
    # 3) mean full-range log2FC vs log2(x_inf)
    # -----------------------------
    scatter_nice_dep_plot(
        x_vals=log2fc_full_mean[xinf_finite_mask],
        y_vals=log2_xinf_mean[xinf_finite_mask],
        dep_mask=dep_mask_xinf[xinf_finite_mask],
        cg=cg,
        target_colors=target_colors,
        xlabel=rf"mean log$_2$FC ($y(x\to\infty)$ vs $y(x\to 0)$)",
        ylabel=rf"log$_2(x_{{\mathrm{{inf}}}})$",
        title=f"{cg}: log$_2(x_{{\mathrm{{inf}}}})$ vs full-range log$_2$FC\n(only |n|>1+{tol_n_for_xinf})",
    )
=== GFI1B: FC/n/x_inf relationships ===
<>:195: SyntaxWarning: invalid escape sequence '\m'
<>:209: SyntaxWarning: invalid escape sequence '\m'
<>:195: SyntaxWarning: invalid escape sequence '\m'
<>:209: SyntaxWarning: invalid escape sequence '\m'
/tmp/ipykernel_2603555/2547606006.py:195: SyntaxWarning: invalid escape sequence '\m'
  title=f"{cg}: log$_2(x_{{\mathrm{{inf}}}})$ vs Hill coefficient\n(only |n|>1+{tol_n_for_xinf})",
/tmp/ipykernel_2603555/2547606006.py:209: SyntaxWarning: invalid escape sequence '\m'
  title=f"{cg}: log$_2(x_{{\mathrm{{inf}}}})$ vs full-range log$_2$FC\n(only |n|>1+{tol_n_for_xinf})",
/tmp/ipykernel_2603555/2547606006.py:32: RuntimeWarning: overflow encountered in divide
  log_xinf = np.log(K_samps) + (1.0 / n_samps) * np.log(base)
/tmp/ipykernel_2603555/2547606006.py:165: RuntimeWarning: Mean of empty slice
  xinf_mean    = np.nanmean(xinf_samps, axis=0)  # [T]
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
=== GEMIN5: FC/n/x_inf relationships ===
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
=== DDX6: FC/n/x_inf relationships ===
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
time: 6.78 s (started: 2025-11-03 20:58:34 +01:00)

Compare to edgeR results¶

comparison plots¶

In [22]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.lines import Line2D

# --------------------------------------------------------------------
# Basic helpers
# --------------------------------------------------------------------
def lighten(color, amount=0.3):
    """Lighten an RGBA/RGB colour by mixing with white."""
    c = np.array(mcolors.to_rgba(color))
    white = np.array([1, 1, 1, 1])
    return tuple((1 - amount) * c + amount * white)

def darken(color, amount=0.3):
    """Darken an RGBA/RGB colour by mixing with black."""
    c = np.array(mcolors.to_rgba(color))
    black = np.array([0, 0, 0, 1])
    return tuple((1 - amount) * c + amount * black)

def dependency_mask_from_n(n_samps, ci=95.0):
    """
    Dependency mask based on n: 95% CI excludes 0.
    n_samps: [S, T]
    """
    lo_q = (100 - ci) / 2.0
    hi_q = 100 - lo_q
    lo, hi = np.percentile(n_samps, [lo_q, hi_q], axis=0)
    return (lo > 0) | (hi < 0)


# --------------------------------------------------------------------
# Hill-based log2FC metrics
# --------------------------------------------------------------------
def compute_log2fc_metrics(A_samps, alpha_samps, Vmax_samps, K_samps, n_samps,
                           x_true_samps, eps=1e-6):
    """
    Compute *directional* log2 fold-change metrics for the Hill-based model:

        y(x) = A + alpha * Vmax * x^n / (K^n + x^n)

    Parameters
    ----------
    A_samps, alpha_samps, Vmax_samps, K_samps, n_samps : [S, T]
        Posterior samples for each parameter.
    x_true_samps : [S, N_cells]
        Posterior samples of x_true for this cis gene.
    eps : float
        Small constant for numerical stability.

    Returns
    -------
    log2fc_full : [S, T]
        Full-range log2FC: y(x→∞) vs y(x→0), *directional w.r.t. x increasing*.
        - n > 0: log2( (A+αVmax) / A )
        - n < 0: log2( A / (A+αVmax) )
    log2fc_obs : [S, T]
        Observed-range log2FC: y(x_max_obs) vs y(x_min_obs), directional in x.
    x_min_obs, x_max_obs : float
        Observed min/max of mean x_true across cells.
    """
    # ensure arrays
    A_samps     = np.asarray(A_samps)
    alpha_samps = np.asarray(alpha_samps)
    Vmax_samps  = np.asarray(Vmax_samps)
    K_samps     = np.asarray(K_samps)
    n_samps     = np.asarray(n_samps)

    # observed x range from mean x_true per cell
    X = np.asarray(x_true_samps)              # [S, N_cells]
    x_means_per_cell = X.mean(axis=0)         # [N_cells]
    x_min_obs = float(x_means_per_cell.min())
    x_max_obs = float(x_means_per_cell.max())

    A     = A_samps
    alpha = alpha_samps
    Vmax  = Vmax_samps

    # --- FULL-RANGE --- #
    # Asymptotes: A (low) and A + α·Vmax (high), but which is at x→0 vs x→∞
    # depends on the sign of n.
    y_low_val  = A
    y_high_val = A + alpha * Vmax

    n_pos = (n_samps >= 0)  # True where curve increases with x

    # y(x→0) and y(x→∞) according to sign of n
    y_at_x0   = np.where(n_pos, y_low_val,  y_high_val)
    y_at_xinf = np.where(n_pos, y_high_val, y_low_val)

    log2fc_full = np.log2((y_at_xinf + eps) / (y_at_x0 + eps))

    # --- helper: y(x) for observed-range FC --- #
    def y_hill(x_scalar, A, alpha, Vmax, K, n, eps_inner=1e-8):
        """
        y(x) = A + alpha * Vmax * x^n / (K^n + x^n)
        evaluated at scalar x, broadcasting over [S, T] params.
        """
        x = float(x_scalar)
        x_safe = x + eps_inner
        K_safe = K + eps_inner

        with np.errstate(divide='ignore', invalid='ignore'):
            x_log = np.log(x_safe)
            K_log = np.log(K_safe)
            x_n = np.exp(n * x_log)
            K_n = np.exp(n * K_log)

        frac = x_n / (K_n + x_n + eps_inner)
        h = Vmax * frac
        return A + alpha * h

    # --- OBSERVED-RANGE --- #
    # y at min and max of empirical x range (direction is x_min→x_max)
    Y_min_obs = y_hill(x_min_obs, A, alpha, Vmax, K_samps, n_samps, eps_inner=eps)
    Y_max_obs = y_hill(x_max_obs, A, alpha, Vmax, K_samps, n_samps, eps_inner=eps)

    log2fc_obs = np.log2((Y_max_obs + eps) / (Y_min_obs + eps))

    return log2fc_full, log2fc_obs, x_min_obs, x_max_obs


def compute_log2fc_obs_for_cells(
    A_samps, alpha_samps, Vmax_samps, K_samps, n_samps,
    x_true_samps, cell_mask, guide_labels=None, eps=1e-6
):
    """
    Compute observed-range log2FC for each gene, given a subset of cells.

    - A_samps, alpha_samps, Vmax_samps, K_samps, n_samps: [S, T]
    - x_true_samps: [S, N_cells] (cis x_true samples)
    - cell_mask: boolean [N_cells], selecting cells to *use* for x_min/x_max
                 (e.g. cells of a given guide + all NTC cells)
    - guide_labels: array-like [N_cells] of guide IDs; if provided, we
                    first average x_true per guide, then take min/max
                    across guide means.
    Returns:
        log2fc_obs: [S, T] (per-sample, per-gene, directional)
        x_min_obs, x_max_obs: floats (observed range in this subset)
    """
    A_samps     = np.asarray(A_samps)
    alpha_samps = np.asarray(alpha_samps)
    Vmax_samps  = np.asarray(Vmax_samps)
    K_samps     = np.asarray(K_samps)
    n_samps     = np.asarray(n_samps)
    X           = np.asarray(x_true_samps)

    # subset cells
    X_sub = X[:, cell_mask]  # [S, N_sub]

    # per-cell means over posterior samples
    x_means_per_cell = X_sub.mean(axis=0)  # [N_sub]

    # optionally aggregate per guide first
    if guide_labels is not None:
        guide_labels = np.asarray(guide_labels)
        guides_sub = guide_labels[cell_mask]  # [N_sub]
        uniq_guides = np.unique(guides_sub)
        perguide_means = []
        for g in uniq_guides:
            perguide_means.append(x_means_per_cell[guides_sub == g].mean())
        perguide_means = np.array(perguide_means)
        x_min_obs = float(perguide_means.min())
        x_max_obs = float(perguide_means.max())
    else:
        x_min_obs = float(x_means_per_cell.min())
        x_max_obs = float(x_means_per_cell.max())

    def y_hill(x_scalar, A, alpha, Vmax, K, n, eps_inner=1e-8):
        """
        y(x) = A + alpha * Vmax * x^n / (K^n + x^n)
        """
        x = float(x_scalar)
        x_safe = x + eps_inner
        K_safe = K + eps_inner

        with np.errstate(divide='ignore', invalid='ignore'):
            x_log = np.log(x_safe)
            K_log = np.log(K_safe)
            x_n = np.exp(n * x_log)
            K_n = np.exp(n * K_log)

        frac = x_n / (K_n + x_n + eps_inner)
        h = Vmax * frac
        return A + alpha * h

    A     = A_samps
    alpha = alpha_samps
    Vmax  = Vmax_samps

    Y_min_obs = y_hill(x_min_obs, A, alpha, Vmax, K_samps, n_samps, eps_inner=eps)
    Y_max_obs = y_hill(x_max_obs, A, alpha, Vmax, K_samps, n_samps, eps_inner=eps)

    log2fc_obs = np.log2((Y_max_obs + eps) / (Y_min_obs + eps))

    return log2fc_obs, x_min_obs, x_max_obs


# --------------------------------------------------------------------
# Common scatter + heatmap plotting helper
# --------------------------------------------------------------------
def scatter_and_heatmap_edger_vs_bayes(
    df_g,
    y_col,
    base_target_color,
    base_ntc_color,
    cg,
    guide,
    ylabel,
    title_suffix,
    fc_thresh=0.5,
    flip_edger_x=True,
):
    """
    For a single guide df_g (already subset to that guide), make:
      - scatter plot: edgeR logFC vs bayesDREAM y_col
      - 3x3 heatmap of category overlap

    df_g must contain: 'logFC', 'ext_sig', 'dependent', y_col.
    """
    g_str = str(guide)

    # 4 colour classes for scatter (computed before finite mask)
    colors = []
    for dep, sig in zip(df_g['dependent'], df_g['ext_sig']):
        if (not dep) and (not sig):
            # neither method calls it
            c = base_ntc_color
        elif sig and (not dep):
            # edgeR only (FDR<0.05, not dependent in bayesDREAM)
            c = lighten(base_target_color, 0.4)
        elif dep and (not sig):
            # bayesDREAM only (dependent, FDR>=0.05)
            c = base_target_color
        else:
            # both: dependent & FDR<0.05
            c = darken(base_target_color, 0.4)
        colors.append(c)

    # restrict to finite values
    finite = np.isfinite(df_g['logFC']) & np.isfinite(df_g[y_col])
    df_plot = df_g[finite]
    if df_plot.empty:
        return
    colors = np.array(colors)[finite.values]

    # x vs y values
    x_raw = df_plot['logFC'].values
    x_vals = -x_raw if flip_edger_x else x_raw
    y_vals = df_plot[y_col].values

    # same scale on x & y
    v_min = min(x_vals.min(), y_vals.min())
    v_max = max(x_vals.max(), y_vals.max())
    pad = 0.05 * (v_max - v_min + 1e-6)
    x_lim = (v_min - pad, v_max + pad)
    y_lim = (v_min - pad, v_max + pad)

    # ---------- SCATTER ----------
    plt.figure(figsize=(5.5, 5))
    plt.scatter(
        x_vals,
        y_vals,
        s=10,
        c=colors,
        alpha=0.8,
        edgecolor='none',
    )

    # reference lines
    plt.axhline(0,    color='black', linestyle=':', linewidth=1)
    plt.axvline(0,    color='black', linestyle=':', linewidth=1)
    plt.axhline( fc_thresh,  color='black', linestyle='--', linewidth=0.8)
    plt.axhline(-fc_thresh,  color='black', linestyle='--', linewidth=0.8)
    plt.axvline( fc_thresh,  color='black', linestyle='--', linewidth=0.8)
    plt.axvline(-fc_thresh,  color='black', linestyle='--', linewidth=0.8)

    plt.xlim(x_lim)
    plt.ylim(y_lim)

    x_label_prefix = '-' if flip_edger_x else ''
    plt.xlabel(fr'{x_label_prefix}log$_2$FC (edgeR; per-guide)')
    plt.ylabel(ylabel)
    plt.title(f'{cg}, guide {g_str}: edgeR vs bayesDREAM {title_suffix}')

    legend_handles = [
        Line2D([0], [0], marker='o', linestyle='',
               color=base_ntc_color, label='neither (edgeR nor bayesDREAM)'),
        Line2D([0], [0], marker='o', linestyle='',
               color=lighten(base_target_color, 0.4),
               label='edgeR only (FDR<0.05)'),
        Line2D([0], [0], marker='o', linestyle='',
               color=base_target_color, label='bayesDREAM only (dependent)'),
        Line2D([0], [0], marker='o', linestyle='',
               color=darken(base_target_color, 0.4),
               label='both (FDR<0.05 & dependent)'),
    ]
    plt.legend(handles=legend_handles, frameon=False, fontsize=8, loc='best')

    plt.grid(True, linewidth=0.5, alpha=0.4)
    plt.tight_layout()
    plt.show()

    # ---------- HEATMAP ----------
    # edgeR: 0 = not sig; 1 = sig & |log2FC| < thresh; 2 = sig & |log2FC| ≥ thresh
    edge_cat = np.zeros(df_plot.shape[0], dtype=int)
    edge_sig = df_plot['ext_sig'].values
    edge_fc  = df_plot['logFC'].values
    edge_cat[ edge_sig & (np.abs(edge_fc) < fc_thresh) ]  = 1
    edge_cat[ edge_sig & (np.abs(edge_fc) >= fc_thresh) ] = 2

    # bayesDREAM: 0 = not dependent; 1 = dep & small; 2 = dep & large
    bayes_cat = np.zeros(df_plot.shape[0], dtype=int)
    bayes_dep = df_plot['dependent'].values
    bayes_fc  = df_plot[y_col].values
    bayes_cat[ bayes_dep & (np.abs(bayes_fc) < fc_thresh) ]  = 1
    bayes_cat[ bayes_dep & (np.abs(bayes_fc) >= fc_thresh) ] = 2

    mat = np.zeros((3, 3), dtype=int)
    for b, e in zip(bayes_cat, edge_cat):
        mat[b, e] += 1

    edge_labels = [
        "not sig",
        fr"sig, |log$_2$FC| < {fc_thresh}",
        fr"sig, |log$_2$FC| ≥ {fc_thresh}",
    ]
    bayes_labels = [
        "not dependent",
        fr"dep, |log$_2$FC| < {fc_thresh}",
        fr"dep, |log$_2$FC| ≥ {fc_thresh}",
    ]

    fig, ax_hm = plt.subplots(figsize=(5.2, 4.5))
    im = ax_hm.imshow(mat, cmap="Blues")

    total = mat.sum()
    for i in range(3):
        for j in range(3):
            count = mat[i, j]
            if total > 0:
                pct = 100.0 * count / total
                txt = f"{count}\n({pct:.1f}%)"
            else:
                txt = "0"
            ax_hm.text(
                j, i, txt,
                ha="center", va="center",
                color="black" if count < total/2 else "white",
                fontsize=8,
            )

    ax_hm.set_xticks([0, 1, 2])
    ax_hm.set_yticks([0, 1, 2])
    ax_hm.set_xticklabels(edge_labels, rotation=30, ha="right")
    ax_hm.set_yticklabels(bayes_labels)

    ax_hm.set_xlabel("edgeR category")
    ax_hm.set_ylabel("bayesDREAM category")
    ax_hm.set_title(f"{cg}, guide {g_str}:\nedgeR vs bayesDREAM categories")

    # no colourbar (values already annotated)
    plt.tight_layout()
    plt.show()


# --------------------------------------------------------------------
# Data preparation helper: shared for full & observed range
# --------------------------------------------------------------------
def prepare_de_for_cg(model, de_df, cg):
    """
    Prepare model + edgeR data for a given cis gene cg.

    Returns:
        A_samps, alpha_samps, Vmax_samps, K_samps, n_samps, x_true_samps,
        meta, de_cg (with idx, logFC, FDR, gene, n_mean, dependent, ext_sig),
        base_target_color, base_ntc_color
    """
    A_samps     = model[cg].posterior_samples_trans['A'][:, 0, :].detach().cpu().numpy()
    alpha_samps = model[cg].posterior_samples_trans['alpha'][:, 0, :].detach().cpu().numpy()
    Vmax_samps  = model[cg].posterior_samples_trans['Vmax_a'][:, 0, :].detach().cpu().numpy()
    K_samps     = model[cg].posterior_samples_trans['K_a'][:, 0, :].detach().cpu().numpy()
    n_samps     = model[cg].posterior_samples_trans['n_a'][:, 0, :].detach().cpu().numpy()
    x_true_samps = model[cg].x_true.detach().cpu().numpy()
    meta = model[cg].meta

    n_mean   = n_samps.mean(axis=0)             # [T]
    dep_mask = dependency_mask_from_n(n_samps)  # [T] bool
    T        = n_mean.shape[0]

    # gene names aligned to posterior arrays
    gene_names = np.array(model[cg].get_modality('gene').feature_meta['gene'])
    if len(gene_names) != T:
        print(f"[{cg}] WARNING: len(gene_names)={len(gene_names)} != T={T}. "
              "Trimming gene_names to first T entries.")
        gene_names = gene_names[:T]

    gene_to_idx = {g: i for i, g in enumerate(gene_names)}

    # external DE results (edgeR)
    de_cg = de_df.copy()
    de_cg = de_cg[de_cg['gene'].isin(gene_to_idx.keys())].copy()
    if de_cg.empty:
        print(f"[{cg}] No overlapping genes in edgeR results, skipping.")
        return None

    de_cg['idx'] = de_cg['gene'].map(gene_to_idx)
    de_cg = de_cg[de_cg['idx'].notna()].copy()
    de_cg['idx'] = de_cg['idx'].astype(int)
    de_cg = de_cg[de_cg['idx'] < T].copy()
    if de_cg.empty:
        print(f"[{cg}] All overlapping edgeR genes had out-of-range indices, skipping.")
        return None

    idx_vals = de_cg['idx'].values
    de_cg['n_mean']    = n_mean[idx_vals]
    de_cg['dependent'] = dep_mask[idx_vals]
    de_cg['ext_sig']   = de_cg['FDR'] < 0.05  # edgeR significance

    base_target_color = target_colors.get(cg, 'blue')
    base_ntc_color    = target_colors.get('NTC', 'grey')

    return (A_samps, alpha_samps, Vmax_samps, K_samps, n_samps,
            x_true_samps, meta, de_cg, base_target_color, base_ntc_color)


# --------------------------------------------------------------------
# 1) Full-range comparison: edgeR vs bayesDREAM log2FC_full (directional)
# --------------------------------------------------------------------
def plot_edger_vs_bayes_full_range(cg_list, model, de_df,
                                   fc_thresh=0.5, flip_edger_x=True):
    for cg in cg_list:
        print(f"\n=== {cg} (full-range) ===")

        prep = prepare_de_for_cg(model, de_df, cg)
        if prep is None:
            continue
        (A_samps, alpha_samps, Vmax_samps, K_samps, n_samps,
         x_true_samps, meta, de_cg, base_target_color, base_ntc_color) = prep

        # full-range log2FC from bayesDREAM
        log2fc_full, _, _, _ = compute_log2fc_metrics(
            A_samps, alpha_samps, Vmax_samps, K_samps, n_samps, x_true_samps
        )
        log2fc_full_mean = log2fc_full.mean(axis=0)  # [T]
        de_cg['log2fc_full'] = log2fc_full_mean[de_cg['idx'].values]

        guides_in_model = set(meta['guide'].astype(str).unique())

        for guide in sorted(de_cg['guide'].unique()):
            g_str = str(guide)
            if g_str not in guides_in_model:
                continue

            df_g = de_cg[de_cg['guide'] == guide].copy()
            if df_g.empty:
                continue

            scatter_and_heatmap_edger_vs_bayes(
                df_g=df_g,
                y_col='log2fc_full',
                base_target_color=base_target_color,
                base_ntc_color=base_ntc_color,
                cg=cg,
                guide=g_str,
                ylabel=r'log$_2$FC$_{\mathrm{full}}$ '
                       r'(bayesDREAM, $x:0\to\infty$)',
                title_suffix='(full-range log$_2$FC)',
                fc_thresh=fc_thresh,
                flip_edger_x=flip_edger_x,
            )


# --------------------------------------------------------------------
# 2) Observed-range comparison: edgeR vs bayesDREAM log2FC_obs (guide+NTC)
# --------------------------------------------------------------------
def plot_edger_vs_bayes_observed_range(cg_list, model, de_df,
                                       fc_thresh=0.5, flip_edger_x=True,
                                       aggregate_by_guide=True):
    for cg in cg_list:
        print(f"\n=== {cg} (observed-range) ===")

        prep = prepare_de_for_cg(model, de_df, cg)
        if prep is None:
            continue
        (A_samps, alpha_samps, Vmax_samps, K_samps, n_samps,
         x_true_samps, meta, de_cg, base_target_color, base_ntc_color) = prep

        guides_in_model = set(meta['guide'].astype(str).unique())
        guide_labels_all = meta['guide'].astype(str).values

        for guide in sorted(de_cg['guide'].unique()):
            g_str = str(guide)
            if g_str not in guides_in_model:
                continue

            df_g = de_cg[de_cg['guide'] == guide].copy()
            if df_g.empty:
                continue

            # cells: this guide + NTC-target cells
            guide_mask = guide_labels_all == g_str
            ntc_mask   = meta['target'].astype(str).str.upper().str.contains('NTC').to_numpy()
            cell_mask  = guide_mask | ntc_mask

            log2fc_obs_guide, x_min_obs, x_max_obs = compute_log2fc_obs_for_cells(
                A_samps, alpha_samps, Vmax_samps, K_samps, n_samps,
                x_true_samps,
                cell_mask=cell_mask,
                guide_labels=guide_labels_all if aggregate_by_guide else None,
            )
            log2fc_obs_mean_guide = log2fc_obs_guide.mean(axis=0)
            de_cg.loc[df_g.index, 'log2fc_obs_guide'] = \
                log2fc_obs_mean_guide[df_g['idx'].values]

            df_g = de_cg.loc[df_g.index].copy()

            scatter_and_heatmap_edger_vs_bayes(
                df_g=df_g,
                y_col='log2fc_obs_guide',
                base_target_color=base_target_color,
                base_ntc_color=base_ntc_color,
                cg=cg,
                guide=g_str,
                ylabel=r'log$_2$FC$_{\mathrm{obs}}$ '
                       r'(bayesDREAM; guide+NTC x-range)',
                title_suffix='(observed log$_2$FC)',
                fc_thresh=fc_thresh,
                flip_edger_x=flip_edger_x,
            )

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.lines import Line2D

def compare_edger_extreme_per_target(model, de_df, cg, log2fc_thresh=0.5):
    """
    For a cis gene `cg`, compare:
      - x: log2FC from edgeR (guide with largest |log2FC| per gene)
      - y: mean full-range log2FC from bayesDREAM (directional)
    Returns:
      dict with the 9 heatmap categories, each a list of gene names.
    """

    # ======= bayesDREAM side =======
    A_samps      = model[cg].posterior_samples_trans['A'][:, 0, :].detach().cpu().numpy()
    alpha_samps  = model[cg].posterior_samples_trans['alpha'][:, 0, :].detach().cpu().numpy()
    Vmax_samps   = model[cg].posterior_samples_trans['Vmax_a'][:, 0, :].detach().cpu().numpy()
    K_samps      = model[cg].posterior_samples_trans['K_a'][:, 0, :].detach().cpu().numpy()
    n_samps      = model[cg].posterior_samples_trans['n_a'][:, 0, :].detach().cpu().numpy()
    x_true_samps = model[cg].x_true.detach().cpu().numpy()
    meta         = model[cg].meta

    # directional full-range FC
    log2fc_full, _, _, _ = compute_log2fc_metrics(
        A_samps, alpha_samps, Vmax_samps, K_samps, n_samps, x_true_samps
    )
    log2fc_full_mean = log2fc_full.mean(axis=0)
    dep_mask = dependency_mask_from_n(n_samps)
    T = log2fc_full_mean.shape[0]

    # map genes
    gene_names = np.array(model[cg].get_modality('gene').feature_meta['gene'])
    if len(gene_names) != T:
        print(f"[{cg}] WARNING: trimming gene list from {len(gene_names)} to {T}")
        gene_names = gene_names[:T]
    gene_to_idx = {g: i for i, g in enumerate(gene_names)}

    # ======= edgeR side =======
    guides = meta['guide'].astype(str).values
    targets = meta['target'].astype(str).values
    target_guides = {g for g, t in zip(guides, targets) if t.upper() == cg.upper()}

    de_cg = de_df[
        de_df['guide'].astype(str).isin(target_guides)
        & de_df['gene'].isin(gene_to_idx.keys())
    ].copy()
    if de_cg.empty:
        print(f"[{cg}] No overlapping edgeR rows for guides targeting {cg}")
        return {}

    # pick guide with largest |logFC|
    idx_max = de_cg.groupby('gene')['logFC'].apply(lambda s: np.argmax(np.abs(s.values)))
    extreme_rows = []
    for gene, idx_local in idx_max.items():
        sub = de_cg[de_cg['gene'] == gene]
        if len(sub) > idx_local:
            extreme_rows.append(sub.iloc[idx_local])
    de_extreme = pd.DataFrame(extreme_rows)

    # any guide significant?
    sig_any = (
        de_cg.groupby('gene')['FDR']
        .apply(lambda s: np.any(s.values < 0.05))
        .reset_index()
        .rename(columns={'FDR': 'any_sig'})
    )
    agg = de_extreme.merge(sig_any, on='gene', how='left')

    # map back to model
    agg['idx'] = agg['gene'].map(gene_to_idx)
    agg = agg[agg['idx'].notna()].copy()
    agg['idx'] = agg['idx'].astype(int)
    agg = agg[agg['idx'] < T].copy()

    agg['log2fc_full'] = log2fc_full_mean[agg['idx'].values]
    agg['dependent']   = dep_mask[agg['idx'].values]

    # ======= extract data =======
    x = - agg['logFC'].values
    y = agg['log2fc_full'].values
    dep = agg['dependent'].values
    sig = agg['any_sig'].values
    genes = agg['gene'].values

    finite = np.isfinite(x) & np.isfinite(y)
    x, y, dep, sig, genes = x[finite], y[finite], dep[finite], sig[finite], genes[finite]

    base_target_color = target_colors.get(cg, 'blue')
    base_ntc_color = target_colors.get('NTC', cm.Greys(0.6))

    def lighten(c, amount=0.3):
        arr = np.array(cm.colors.to_rgba(c))
        return tuple((1 - amount) * arr + amount * np.array([1, 1, 1, 1]))

    def darken(c, amount=0.3):
        arr = np.array(cm.colors.to_rgba(c))
        return tuple((1 - amount) * arr + amount * np.array([0, 0, 0, 1]))

    # ======= plot: scatter =======
    colors = []
    for d, sgn in zip(dep, sig):
        if (not d) and (not sgn):
            c = base_ntc_color
        elif sgn and (not d):
            c = lighten(base_target_color, 0.4)
        elif d and (not sgn):
            c = base_target_color
        else:
            c = darken(base_target_color, 0.4)
        colors.append(c)

    fig, ax = plt.subplots(figsize=(5.5, 5.5))
    ax.set_box_aspect(1)
    ax.scatter(x, y, s=10, c=colors, alpha=0.8, edgecolor='none')
    for t in [0, log2fc_thresh, -log2fc_thresh]:
        ax.axhline(t, color='black', linestyle='--' if t != 0 else ':', linewidth=0.8)
        ax.axvline(t, color='black', linestyle='--' if t != 0 else ':', linewidth=0.8)
    pad = 0.1 * max(np.ptp(x), np.ptp(y))
    ax.set_xlim(x.min() - pad, x.max() + pad)
    ax.set_ylim(y.min() - pad, y.max() + pad)
    ax.set_xlabel(r'-log$_2$FC (edgeR; guide with largest |log$_2$FC|)')
    ax.set_ylabel(r'mean full-range log$_2$FC (bayesDREAM)')
    ax.set_title(f"{cg}: edgeR (extreme guide) vs bayesDREAM log$_2$FC$_{{full}}$")
    legend_handles = [
        Line2D([0], [0], marker='o', linestyle='', color=base_ntc_color, label='neither'),
        Line2D([0], [0], marker='o', linestyle='', color=lighten(base_target_color, 0.4), label='edgeR only'),
        Line2D([0], [0], marker='o', linestyle='', color=base_target_color, label='bayesDREAM only'),
        Line2D([0], [0], marker='o', linestyle='', color=darken(base_target_color, 0.4), label='both'),
    ]
    leg = ax.legend(handles=legend_handles, frameon=False, loc="center left",
                    bbox_to_anchor=(1.02, 0.5), borderaxespad=0.0, fontsize=8)
    for lh in leg.legend_handles:
        try: lh.set_sizes([50])
        except Exception: pass
    fig.subplots_adjust(right=0.78)
    ax.grid(True, linewidth=0.5, alpha=0.4)
    plt.show()

    # ======= confusion categories =======
    edge_cat = np.zeros(x.shape[0], dtype=int)
    edge_cat[sig & (np.abs(x) <  log2fc_thresh)] = 1
    edge_cat[sig & (np.abs(x) >= log2fc_thresh)] = 2
    bayes_cat = np.zeros(y.shape[0], dtype=int)
    bayes_cat[dep & (np.abs(y) <  log2fc_thresh)] = 1
    bayes_cat[dep & (np.abs(y) >= log2fc_thresh)] = 2

    mat = np.zeros((3, 3), dtype=int)
    cat_dict = {}
    for b in range(3):
        for e in range(3):
            mask = (bayes_cat == b) & (edge_cat == e)
            mat[b, e] = np.sum(mask)
            key = f"bayes{b}_edge{e}"
            cat_dict[key] = genes[mask].tolist()

    # readable key mapping
    readable = {
        "bayes0_edge0": "notdep_notsig",
        "bayes0_edge1": "notdep_sigsmall",
        "bayes0_edge2": "notdep_siglarge",
        "bayes1_edge0": "depsmall_notsig",
        "bayes1_edge1": "depsmall_sigsmall",
        "bayes1_edge2": "depsmall_siglarge",
        "bayes2_edge0": "deplarge_notsig",
        "bayes2_edge1": "deplarge_sigsmall",
        "bayes2_edge2": "deplarge_siglarge",
    }
    gene_comp = {readable[k]: cat_dict[k] for k in readable}

    # ======= heatmap =======
    edge_labels = ["not sig", f"sig < {log2fc_thresh}", f"sig ≥ {log2fc_thresh}"]
    bayes_labels = ["not dep", f"dep < {log2fc_thresh}", f"dep ≥ {log2fc_thresh}"]
    fig_hm, ax_hm = plt.subplots(figsize=(5.2, 4.5))
    im = ax_hm.imshow(mat, cmap="Blues")
    total = mat.sum()
    for i in range(3):
        for j in range(3):
            count = mat[i, j]
            pct = 100.0 * count / total if total > 0 else 0
            ax_hm.text(j, i, f"{count}\n({pct:.1f}%)", ha="center", va="center",
                       color="black" if count < total/2 else "white", fontsize=8)
    ax_hm.set_xticks([0, 1, 2])
    ax_hm.set_yticks([0, 1, 2])
    ax_hm.set_xticklabels(edge_labels, rotation=30, ha="right")
    ax_hm.set_yticklabels(bayes_labels)
    ax_hm.set_xlabel("edgeR category")
    ax_hm.set_ylabel("bayesDREAM category")
    ax_hm.set_title(f"{cg}: edgeR vs bayesDREAM categories")
    plt.tight_layout()
    plt.show()

    return gene_comp
time: 14.6 ms (started: 2025-11-03 21:15:33 +01:00)
In [23]:
# --------------------------------------------------------------------
# Load edgeR results and call the functions
# --------------------------------------------------------------------
de_df = pd.read_csv(
    "/cfs/klemming/projects/snic/lappalainen_lab1/users/Leah/"
    "data/unpublished/Josie_pilot/processed/fromJosie_Oct2025/de_results_per_guide.csv"
)

# target_colors must exist already, e.g. your earlier map using cm.Greens, etc.
# Example call:
cis_genes = ['GFI1B', 'GEMIN5', 'DDX6']

# Full-range comparison
plot_edger_vs_bayes_full_range(cis_genes, model, de_df,
                               fc_thresh=0.5, flip_edger_x=True)

# Observed-range (guide+NTC) comparison
plot_edger_vs_bayes_observed_range(cis_genes, model, de_df,
                                   fc_thresh=0.5, flip_edger_x=True,
                                   aggregate_by_guide=True)
=== GFI1B (full-range) ===
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
=== GEMIN5 (full-range) ===
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
=== DDX6 (full-range) ===
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
=== GFI1B (observed-range) ===
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
=== GEMIN5 (observed-range) ===
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
=== DDX6 (observed-range) ===
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
time: 18.3 s (started: 2025-11-03 21:15:33 +01:00)
In [24]:
# === run and collect ===
gene_comp = {}
for cg in ['GFI1B', 'GEMIN5', 'DDX6']:
    gene_comp[cg] = compare_edger_extreme_per_target(model, de_df, cg)
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
time: 1min 24s (started: 2025-11-03 21:15:52 +01:00)
In [25]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm

def volcano_edger_per_guide(
    model,
    de_df,
    cg,
    fdr_thr=0.05,
    logfc_thr=0.5,
    n_top_label=10,
):
    """
    Volcano plots per guide (edgeR), highlighting genes that are:
      - not dependent in bayesDREAM
      - FDR < fdr_thr
      - |log2FC| > logfc_thr

    Top n_top_label genes per guide (smallest FDR) are annotated.
    """

    # ---- bayesDREAM dependency per gene for this cis gene ----
    n_samps = model[cg].posterior_samples_trans["n_a"][:, 0, :].detach().cpu().numpy()
    dep_mask = dependency_mask_from_n(n_samps)           # [T] bool
    T = dep_mask.shape[0]

    gene_names = np.array(model[cg].get_modality('gene').feature_meta['gene'])
    if len(gene_names) != T:
        print(f"[{cg}] WARNING: len(gene_names)={len(gene_names)} != T={T}; trimming.")
        gene_names = gene_names[:T]
    gene_to_idx = {g: i for i, g in enumerate(gene_names)}

    # guides for this model
    guides_in_model = set(model[cg].meta["guide"].astype(str).unique())

    base_target_color = target_colors.get(cg, "tab:blue")
    highlight_color   = base_target_color
    background_color  = "lightgrey"

    # ---- loop over each guide present in edgeR & this model ----
    for guide in sorted(de_df["guide"].astype(str).unique()):
        g_str = str(guide)
        if g_str not in guides_in_model:
            continue

        de_g = de_df[de_df["guide"].astype(str) == g_str].copy()
        if de_g.empty:
            continue

        # map genes to bayesDREAM indices (where available)
        de_g["idx"] = de_g["gene"].map(gene_to_idx)

        # start as float column full of NaN
        de_g["dependent"] = np.nan
        mask_in_model = de_g["idx"].notna()
        if mask_in_model.any():
            idx_local = de_g.loc[mask_in_model, "idx"].astype(int).values
            de_g.loc[mask_in_model, "dependent"] = dep_mask[idx_local].astype(float)

        # numeric columns
        logFC = de_g["logFC"].to_numpy(dtype=float)
        FDR   = de_g["FDR"].to_numpy(dtype=float)

        # handle FDR=0 for -log10
        FDR_safe = np.where(FDR <= 0, 1e-18, FDR)
        neg_log10_FDR = -np.log10(FDR_safe)
        de_g["neg_log10_FDR"] = neg_log10_FDR

        # ---- highlight mask: non-dependent & sig & big effect ----
        dep = de_g["dependent"].to_numpy(dtype=float)   # float with NaNs
        dep_known = ~np.isnan(dep)
        not_dep   = dep_known & (dep == 0.0)
        sig       = FDR < fdr_thr
        big       = np.abs(logFC) > logfc_thr

        highlight = not_dep & sig & big
        de_g["highlight"] = highlight

        # ---- Volcano plot ----
        fig, ax = plt.subplots(figsize=(5.5, 5.5))

        # reference lines
        ax.axhline(-np.log10(fdr_thr), color="black", linestyle="--", linewidth=0.8, alpha=0.5)
        ax.axhline(-np.log10(1e-18),  color="black", linestyle="--", linewidth=0.8, alpha=0.5)
        ax.axvline( logfc_thr,  color="black", linestyle="--", linewidth=0.8, alpha=0.5)
        ax.axvline(-logfc_thr,  color="black", linestyle="--", linewidth=0.8, alpha=0.5)

        # background: all points
        ax.scatter(
            logFC,
            neg_log10_FDR,
            s=10,
            alpha=0.8,
            color=background_color,
            edgecolor="none",
        )
        
        # highlighted points
        if highlight.any():
            ax.scatter(
                logFC[highlight],
                neg_log10_FDR[highlight],
                s=10,
                alpha=0.8,
                color=highlight_color,
                edgecolor="none",
                linewidth=0.3,
            )

        ax.set_xlabel(r"log$_2$FC (edgeR)")
        ax.set_ylabel(r"$- \log_{10}(\mathrm{FDR})$")
        ax.set_title(
            f"{cg}, guide {g_str}: edgeR volcano\n"
            f"(non-dependent & sig & |log2FC|>{logfc_thr} highlighted)"
        )

        # label top n_top_label genes by smallest FDR
        de_g_sorted = de_g.sort_values("FDR", ascending=True)
        to_label = de_g_sorted.iloc[:n_top_label].copy()
        for _, row in to_label.iterrows():
            x0 = float(row["logFC"])
            y0 = float(row["neg_log10_FDR"])
            gene_name = row["gene"]
            ax.text(
                x0,
                y0,
                gene_name,
                fontsize=7,
                ha="center",
                va="bottom",
                alpha=0.9,
            )

        # no legend
        ax.grid(True, linewidth=0.5, alpha=0.4)
        plt.tight_layout()
        plt.show()
time: 3.82 ms (started: 2025-11-03 21:17:16 +01:00)

assess problematic genes¶

In [26]:
for cg in ["GFI1B", "GEMIN5", "DDX6"]:
    volcano_edger_per_guide(model, de_df, cg, fdr_thr=0.05, logfc_thr=0.5, n_top_label=0)
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
time: 2.46 s (started: 2025-11-03 21:17:16 +01:00)
In [50]:
problematic_genes = de_df[(de_df['guide'].str.contains('GFI1B')) & \
    (de_df['gene'].isin(gene_comp['GFI1B']['notdep_siglarge'])) &\
    (de_df['FDR']<0.05) &\
    ((de_df['logFC']).abs()>0.5)
]
plt.hist(problematic_genes.value_counts('gene'))
plt.ylabel('number of genes')
plt.xlabel('number of guides in which they are significant')
plt.show()
triple_problems = problematic_genes.value_counts('gene').index[problematic_genes.value_counts('gene') == 3]
print(f'Genes problematic in 3 guides: {triple_problems}')
No description has been provided for this image
Genes problematic in 3 guides: Index(['AREG', 'CLC', 'CRYBG2', 'CCL3', 'MT1G', 'STC1', 'MMP1'], dtype='object', name='gene')
time: 110 ms (started: 2025-11-03 21:24:05 +01:00)
In [53]:
problematic_genes[problematic_genes['gene'].isin(triple_problems)]
Out[53]:
logFC logCPM F PValue FDR gene guide
28238 -1.838350 5.991872 18.909756 1.393749e-05 2.934746e-04 STC1 GFI1B_1
28416 -1.799965 5.802117 14.828829 1.195865e-04 1.979997e-03 CCL3 GFI1B_1
28435 -1.795806 5.914566 14.526167 1.408513e-04 2.280073e-03 CLC GFI1B_1
28499 -1.489763 5.778868 13.742414 2.135175e-04 3.214884e-03 MT1G GFI1B_1
28955 -0.892110 5.792663 9.541958 2.021793e-03 2.032403e-02 CRYBG2 GFI1B_1
29148 -1.049437 5.842847 8.339231 3.899421e-03 3.436474e-02 AREG GFI1B_1
29304 -1.033280 5.782471 7.655913 5.714823e-03 4.579828e-02 MMP1 GFI1B_1
41381 3.522624 5.914566 373.628880 0.000000e+00 0.000000e+00 CLC GFI1B_2
42041 -1.943067 5.802117 24.822402 6.555345e-07 1.357527e-05 CCL3 GFI1B_2
42404 -1.289393 5.842847 16.812098 4.204179e-05 5.634990e-04 AREG GFI1B_2
42483 -1.385856 5.778868 15.758315 7.368590e-05 9.165598e-04 MT1G GFI1B_2
43111 -1.033280 5.782471 10.694745 1.094091e-03 8.692227e-03 MMP1 GFI1B_2
43471 -0.921560 5.991872 8.781050 3.056014e-03 2.010904e-02 STC1 GFI1B_2
43637 -0.687874 5.792663 8.173919 4.271267e-03 2.604302e-02 CRYBG2 GFI1B_2
55195 -1.938152 5.914566 47.425777 6.858900e-12 3.378506e-09 CLC GFI1B_3
55197 -1.891080 5.802117 44.746268 2.557740e-11 1.175879e-08 CCL3 GFI1B_3
55282 -1.201129 5.778868 22.627513 2.060848e-06 2.471584e-04 MT1G GFI1B_3
55358 -0.981557 5.782471 17.414753 3.143810e-05 2.270127e-03 MMP1 GFI1B_3
55436 -0.660349 5.792663 14.509890 1.414780e-04 7.253773e-03 CRYBG2 GFI1B_3
55610 -0.703603 5.991872 10.591538 1.142519e-03 3.557025e-02 STC1 GFI1B_3
55651 -0.683465 5.842847 10.228842 1.392780e-03 3.963096e-02 AREG GFI1B_3
time: 7.59 ms (started: 2025-11-03 21:33:10 +01:00)
In [27]:
problematic_genes = de_df[(de_df['guide'].str.contains('GFI1B')) & \
    (de_df['gene'].isin(gene_comp['GFI1B']['notdep_siglarge'])) &\
    (de_df['FDR']==0)
]
problematic_genes
Out[27]:
logFC logCPM F PValue FDR gene guide
27588 4.123429 5.680410 329.657290 0.0 0.0 IL1B GFI1B_1
41381 3.522624 5.914566 373.628880 0.0 0.0 CLC GFI1B_2
41385 2.922150 7.001393 279.075077 0.0 0.0 PRSS2 GFI1B_2
41386 3.651268 5.883702 296.803656 0.0 0.0 CCL4L2 GFI1B_2
41466 2.184632 5.656408 100.945521 0.0 0.0 PLEK GFI1B_2
41482 1.992146 6.092199 86.716443 0.0 0.0 CCL3L3 GFI1B_2
55169 -3.216667 6.684942 97.740010 0.0 0.0 REN GFI1B_3
time: 47.5 ms (started: 2025-11-03 21:17:18 +01:00)
In [28]:
cgs = ['GFI1B', 'GEMIN5', 'DDX6']
for cg in cgs:
    model[cg].set_technical_groups(['Sample'])
[INFO] Set technical_group_code with 1 groups based on ['Sample']
[INFO] Set technical_group_code with 1 groups based on ['Sample']
[INFO] Set technical_group_code with 1 groups based on ['Sample']
time: 4.76 ms (started: 2025-11-03 21:17:18 +01:00)
In [31]:
for tg in problematic_genes['gene'].values:
    model['GFI1B'].plot_xy_data(
        tg,
        window=100,
        sum_factor_col='sum_factor_new',
        show_correction='uncorrected'
    )
    
    model['GFI1B'].plot_xy_data(
        tg,
        window=100,
        sum_factor_col='clustered.sum.factor',
        show_correction='uncorrected'
    )
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
time: 3.29 s (started: 2025-11-03 21:18:50 +01:00)
In [51]:
for tg in triple_problems:
    model['GFI1B'].plot_xy_data(
        tg,
        window=100,
        sum_factor_col='sum_factor_new',
        show_correction='uncorrected'
    )
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
time: 1.59 s (started: 2025-11-03 21:28:50 +01:00)
In [ ]: